Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
SenseNova-SI
Commits
876a36a4
Commit
876a36a4
authored
May 27, 2026
by
raojy
Browse files
first
parent
eda2afb8
Changes
175
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5223 additions
and
0 deletions
+5223
-0
SenseNova-SI-main/examples/Q8_4.jpg
SenseNova-SI-main/examples/Q8_4.jpg
+0
-0
SenseNova-SI-main/examples/Q9.jpg
SenseNova-SI-main/examples/Q9.jpg
+0
-0
SenseNova-SI-main/examples/bagel-generate-example.jpg
SenseNova-SI-main/examples/bagel-generate-example.jpg
+0
-0
SenseNova-SI-main/examples/bagel-think_generate-example.jpg
SenseNova-SI-main/examples/bagel-think_generate-example.jpg
+0
-0
SenseNova-SI-main/pyproject.toml
SenseNova-SI-main/pyproject.toml
+69
-0
SenseNova-SI-main/sensenova_si/__init__.py
SenseNova-SI-main/sensenova_si/__init__.py
+36
-0
SenseNova-SI-main/sensenova_si/bagel.py
SenseNova-SI-main/sensenova_si/bagel.py
+406
-0
SenseNova-SI-main/sensenova_si/bagel_utils/__init__.py
SenseNova-SI-main/sensenova_si/bagel_utils/__init__.py
+0
-0
SenseNova-SI-main/sensenova_si/bagel_utils/data/__init__.py
SenseNova-SI-main/sensenova_si/bagel_utils/data/__init__.py
+2
-0
SenseNova-SI-main/sensenova_si/bagel_utils/data/data_utils.py
...eNova-SI-main/sensenova_si/bagel_utils/data/data_utils.py
+198
-0
SenseNova-SI-main/sensenova_si/bagel_utils/data/transforms.py
...eNova-SI-main/sensenova_si/bagel_utils/data/transforms.py
+306
-0
SenseNova-SI-main/sensenova_si/bagel_utils/inferencer.py
SenseNova-SI-main/sensenova_si/bagel_utils/inferencer.py
+362
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/__init__.py
...ova-SI-main/sensenova_si/bagel_utils/modeling/__init__.py
+4
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/autoencoder.py
...-SI-main/sensenova_si/bagel_utils/modeling/autoencoder.py
+386
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/__init__.py
...-main/sensenova_si/bagel_utils/modeling/bagel/__init__.py
+17
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/bagel.py
...-SI-main/sensenova_si/bagel_utils/modeling/bagel/bagel.py
+1215
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/modeling_utils.py
...sensenova_si/bagel_utils/modeling/bagel/modeling_utils.py
+153
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/qwen2_navit.py
...in/sensenova_si/bagel_utils/modeling/bagel/qwen2_navit.py
+1457
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/siglip_navit.py
...n/sensenova_si/bagel_utils/modeling/bagel/siglip_navit.py
+419
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/cache_utils/taylorseer.py
...nsenova_si/bagel_utils/modeling/cache_utils/taylorseer.py
+193
-0
No files found.
SenseNova-SI-main/examples/Q8_4.jpg
0 → 100644
View file @
876a36a4
58.5 KB
SenseNova-SI-main/examples/Q9.jpg
0 → 100644
View file @
876a36a4
226 KB
SenseNova-SI-main/examples/bagel-generate-example.jpg
0 → 100644
View file @
876a36a4
86.4 KB
SenseNova-SI-main/examples/bagel-think_generate-example.jpg
0 → 100644
View file @
876a36a4
66.5 KB
SenseNova-SI-main/pyproject.toml
0 → 100644
View file @
876a36a4
[project]
name
=
"SenseNova-SI"
version
=
"0.1.0"
description
=
"Scaling Spatial Intelligence with Multimodal Foundation Models"
readme
=
"README.md"
requires-python
=
">=3.11"
keywords
=
[
"computer vision"
,
"multimodal"
,
"spatial intelligence"
,
"MLLM"
]
dependencies
=
[
"transformers>=4.57.0"
,
"Pillow"
,
"numpy"
,
"setuptools"
,
"einops>=0.8.1"
,
"timm>=1.0.22"
,
"accelerate>=1.11.0"
,
"opencv-python>=4.11.0.86"
,
]
[dependency-groups]
flash-attn
=
["flash-attn==2.7.4.post1"]
dev
=
["ruff==0.14.4"]
[project.optional-dependencies]
cu118
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu121
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu124
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu126
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu128
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu129
=
[
"torch>=2.4.0"
,
"torchvision"
]
[tool.uv]
default-groups
=
["flash-attn"]
no-build-isolation-package
=
[
'flash-attn'
,
'setuptools'
]
conflicts
=
[
[
{
extra
=
"cu118"
}
,
{
extra
=
"cu121"
}
,
{
extra
=
"cu124"
}
,
{
extra
=
"cu126"
}
,
{
extra
=
"cu128"
}
,
{
extra
=
"cu129"
}
,
],
]
index
=
[
{
name
=
"pytorch-cu118"
,
url
=
"https://download.pytorch.org/whl/cu118"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu121"
,
url
=
"https://download.pytorch.org/whl/cu121"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu124"
,
url
=
"https://download.pytorch.org/whl/cu124"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu126"
,
url
=
"https://download.pytorch.org/whl/cu126"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu128"
,
url
=
"https://download.pytorch.org/whl/cu128"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu129"
,
url
=
"https://download.pytorch.org/whl/cu129"
,
explicit
=
true
}
,
]
[tool.uv.sources]
torch
=
[
{
index
=
"pytorch-cu118"
,
extra
=
"cu118"
}
,
{
index
=
"pytorch-cu121"
,
extra
=
"cu121"
}
,
{
index
=
"pytorch-cu124"
,
extra
=
"cu124"
}
,
{
index
=
"pytorch-cu126"
,
extra
=
"cu126"
}
,
{
index
=
"pytorch-cu128"
,
extra
=
"cu128"
}
,
{
index
=
"pytorch-cu129"
,
extra
=
"cu129"
}
,
]
torchvision
=
[
{
index
=
"pytorch-cu118"
,
extra
=
"cu118"
}
,
{
index
=
"pytorch-cu121"
,
extra
=
"cu121"
}
,
{
index
=
"pytorch-cu124"
,
extra
=
"cu124"
}
,
{
index
=
"pytorch-cu126"
,
extra
=
"cu126"
}
,
{
index
=
"pytorch-cu128"
,
extra
=
"cu128"
}
,
{
index
=
"pytorch-cu129"
,
extra
=
"cu129"
}
,
]
SenseNova-SI-main/sensenova_si/__init__.py
0 → 100644
View file @
876a36a4
from
.bagel
import
SenseNovaSIBagelModel
from
.internvl
import
SenseNovaSIInternVLModel
from
.qwen
import
SenseNovaSIQwenModel
def get_default_model_type(model_path):
    """Guess the model family from a model path or repo id.

    The lower-cased ``model_path`` is searched for the substrings ``"qwen"``,
    ``"internvl"`` and ``"bagel"``, in that priority order, and the first
    match is returned.

    Args:
        model_path: Path or repo id string to inspect.

    Returns:
        One of ``"qwen"``, ``"internvl"`` or ``"bagel"``.

    Raises:
        ValueError: If none of the known family names occurs in the path.
    """
    lowered = model_path.lower()
    for family in ("qwen", "internvl", "bagel"):
        if family in lowered:
            return family
    raise ValueError(f"Unknown model type for {model_path}")
def get_model(model_path, model_type="auto"):
    """Instantiate the SenseNova-SI wrapper matching ``model_type``.

    Args:
        model_path: Path or repo id forwarded to the model constructor.
        model_type: ``"qwen"``, ``"internvl"``, ``"bagel"`` or ``"auto"``.
            With ``"auto"`` the type is derived from ``model_path`` via
            ``get_default_model_type``.

    Returns:
        A newly constructed model wrapper instance.

    Raises:
        ValueError: If the (resolved) model type is not one of the known types.
    """
    resolved = get_default_model_type(model_path) if model_type == "auto" else model_type

    if resolved == "qwen":
        return SenseNovaSIQwenModel(model_path)
    if resolved == "internvl":
        return SenseNovaSIInternVLModel(model_path)
    if resolved == "bagel":
        return SenseNovaSIBagelModel(model_path)
    raise ValueError(f"Unknown model type: {resolved}")
__all__
=
[
"get_default_model_type"
,
"get_model"
,
"SenseNovaSIInternVLModel"
,
"SenseNovaSIQwenModel"
,
"SenseNovaSIBagelModel"
,
]
SenseNova-SI-main/sensenova_si/bagel.py
0 → 100644
View file @
876a36a4
import
os
import
uuid
from
datetime
import
datetime
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
import
torch
from
accelerate
import
(
infer_auto_device_map
,
init_empty_weights
,
load_checkpoint_and_dispatch
,
)
from
huggingface_hub
import
snapshot_download
from
PIL
import
Image
from
.bagel_utils.data.transforms
import
ImageTransform
from
.bagel_utils.inferencer
import
InterleaveInferencer
from
.bagel_utils.modeling.autoencoder
import
load_ae
from
.bagel_utils.modeling.bagel
import
(
Bagel
,
BagelConfig
,
Qwen2Config
,
Qwen2ForCausalLM
,
SiglipVisionConfig
,
SiglipVisionModel
,
)
from
.bagel_utils.modeling.qwen2
import
Qwen2Tokenizer
from
.model
import
Model
from
.utils
import
add_special_tokens
# Default inference parameters for every supported Bagel mode.  The selected
# entry is copied per call; the "think" and "understanding_output" keys are
# popped and passed as dedicated flags, the rest is forwarded as keyword
# arguments to the inferencer.
BASE_PARAMS: Dict[str, Dict[str, Any]] = {
    "generate": {
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 1.0,
        "cfg_interval": [0.4, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 1.0,
        "cfg_renorm_type": "global",
    },
    "think_generate": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 1.0,
        "cfg_interval": [0.4, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 1.0,
        "cfg_renorm_type": "global",
        "think": True,
    },
    "edit": {
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 2.0,
        "cfg_interval": [0.0, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 0.0,
        "cfg_renorm_type": "text_channel",
    },
    "think_edit": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 2.0,
        "cfg_interval": [0.0, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 0.0,
        "cfg_renorm_type": "text_channel",
        "think": True,
    },
    "understanding": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "understanding_output": True,
    },
    "think_understanding": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "understanding_output": True,
        "think": True,
    },
}
class SenseNovaSIBagelModel(Model):
    """SenseNova-SI wrapper around the BAGEL model.

    The constructor builds the network from the config files found under the
    model directory, loads ``model.safetensors`` with the requested precision
    and creates an ``InterleaveInferencer``.  ``generate`` then runs one
    inference whose behaviour is selected by ``self.mode`` (a key of
    ``BASE_PARAMS``).
    """

    # Precisions handled by ``_load_model_weights``.
    _SUPPORTED_PRECISIONS = ("bf16", "nf4", "int8")
    # Modes whose result is an image written to disk.
    _EDIT_MODES = ("edit", "think_edit")
    _GENERATE_MODES = ("generate", "think_generate")

    def __init__(
        self,
        model_path="sensenova/SenseNova-SI-1.1-BAGEL-7B-MoT",
        generation_config: dict[str, Any] | str | os.PathLike | None = None,
        mode="understanding",
        out_img_dir="./output_images/test_bagel/",
        dtype: str = "bf16",
    ):
        """Resolve settings, build the model, load weights and the inferencer.

        Args:
            model_path: Existing local directory, or a repo id that is fetched
                with ``snapshot_download`` when the path does not exist.
            generation_config: Forwarded unchanged to the base class.
            mode: One of the keys of ``BASE_PARAMS``; overridden by a
                non-blank ``BAGEL_MODE`` environment variable.
            out_img_dir: Root directory for saved images; overridden by a
                non-blank ``BAGEL_OUT_IMG_DIR`` environment variable.
            dtype: Loading precision: ``"bf16"``, ``"nf4"`` or ``"int8"``.

        Raises:
            ValueError: If the resolved mode is not a key of ``BASE_PARAMS``.
            NotImplementedError: If ``dtype`` is not a supported precision.
        """
        super().__init__(generation_config)

        # 1. Parse params.  Cheap validation runs first so that a bad mode or
        # precision fails before any download or model construction happens.
        self.precision = dtype
        if self.precision not in self._SUPPORTED_PRECISIONS:
            raise NotImplementedError(f"Unsupported precision: {self.precision}")

        env_mode = (os.getenv("BAGEL_MODE") or "").strip()
        if env_mode:
            mode = env_mode
        if mode not in BASE_PARAMS:
            raise ValueError(
                f"Invalid mode '{mode}'. "
                f"Bagel Supported modes: {list(BASE_PARAMS.keys())}"
            )
        self.mode = mode

        env_out_img_dir = (os.getenv("BAGEL_OUT_IMG_DIR") or "").strip()
        self.out_img_dir = env_out_img_dir if env_out_img_dir else out_img_dir

        print(
            f"[Bagel] mode = '{self.mode}' "
            "(can be overridden with env var BAGEL_MODE); "
            f"out_img_dir = '{self.out_img_dir}' "
            "(can be overridden with env var BAGEL_OUT_IMG_DIR)"
        )

        if os.path.exists(model_path):
            cache_path = model_path
        else:
            cache_path = snapshot_download(repo_id=model_path)
        self.model_path = cache_path
        self.checkpoint_path = os.path.join(self.model_path, "model.safetensors")

        # 2. Build model
        (
            model,
            vae_model,
            tokenizer,
            new_token_ids,
            vit_transform,
            vae_transform,
        ) = self._build_model()

        # 3. Load checkpoint
        model = self._load_model_weights(model)

        # 4. Build inferencer
        self.tokenizer = tokenizer
        self.new_token_ids = new_token_ids
        self.vit_transform = vit_transform
        self.inferencer = InterleaveInferencer(
            model=model,
            vae_model=vae_model,
            tokenizer=tokenizer,
            vae_transform=vae_transform,
            vit_transform=vit_transform,
            new_token_ids=new_token_ids,
        )
        torch.cuda.empty_cache()

    def _build_model(self):
        """Create the (weight-less) Bagel model and its helper objects.

        Reads ``llm_config.json``, ``vit_config.json`` and ``ae.safetensors``
        from ``self.model_path`` and loads the tokenizer from the same
        directory.

        Returns:
            Tuple ``(model, vae_model, tokenizer, new_token_ids,
            vit_transform, vae_transform)``.
        """
        # build llm config
        llm_config = Qwen2Config.from_json_file(
            os.path.join(self.model_path, "llm_config.json")
        )
        llm_config.qk_norm = True
        llm_config.tie_word_embeddings = False
        llm_config.layer_module = "Qwen2MoTDecoderLayer"

        # build vit config
        vit_config = SiglipVisionConfig.from_json_file(
            os.path.join(self.model_path, "vit_config.json")
        )
        vit_config.rope = False
        vit_config.num_hidden_layers -= 1

        # positional arguments: (max_image_size, min_image_size, image_stride)
        vit_transform = ImageTransform(980, 224, 14)
        vae_transform = ImageTransform(1024, 512, 16)

        # build vae config
        vae_model, vae_config = load_ae(
            local_path=os.path.join(self.model_path, "ae.safetensors")
        )

        # build tokenizer
        tokenizer = Qwen2Tokenizer.from_pretrained(self.model_path)
        tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

        # build model
        model_config = BagelConfig(
            visual_gen=True,
            visual_und=True,
            llm_config=llm_config,
            vit_config=vit_config,
            vae_config=vae_config,
            latent_patch_size=2,
            max_latent_size=64,
            vit_max_num_patch_per_side=70,
            connector_act="gelu_pytorch_tanh",
        )
        # Modules are created under init_empty_weights; the real weights are
        # filled in later by _load_model_weights.
        with init_empty_weights():
            language_model = Qwen2ForCausalLM(llm_config)
            vit_model = SiglipVisionModel(vit_config)
            model = Bagel(language_model, vit_model, model_config)
            model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)

        return model, vae_model, tokenizer, new_token_ids, vit_transform, vae_transform

    def _load_model_weights(self, model):
        """Load the checkpoint into ``model`` using ``self.precision``.

        Args:
            model: The model returned by ``_build_model``.

        Returns:
            The loaded model, switched to eval mode.

        Raises:
            NotImplementedError: If ``self.precision`` is not supported.
        """
        device_map = infer_auto_device_map(
            model, no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"]
        )

        # Modules that are pinned to the device of the first entry.
        same_device_modules = [
            "language_model.model.embed_tokens",
            "time_embedder",
            "latent_pos_embed",
            "vae2llm",
            "llm2vae",
            "connector",
            "vit_pos_embed",
        ]

        if torch.cuda.device_count() == 1:
            first_device = device_map.get(same_device_modules[0], "cuda:0")
            for k in same_device_modules:
                device_map[k] = first_device if k in device_map else "cuda:0"
        else:
            # NOTE(review): if the first module has no entry of its own in
            # device_map, first_device is None and is still written to the
            # other entries -- confirm this is intended.
            first_device = device_map.get(same_device_modules[0])
            for k in same_device_modules:
                if k in device_map:
                    device_map[k] = first_device

        if self.precision == "bf16":
            model = load_checkpoint_and_dispatch(
                model,
                checkpoint=self.checkpoint_path,
                device_map=device_map,
                offload_buffers=True,
                offload_folder="offload",
                dtype=torch.bfloat16,
                force_hooks=True,
            ).eval()
        elif self.precision in ("nf4", "int8"):
            from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model

            if self.precision == "nf4":
                bnb_config = BnbQuantizationConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=False,
                    bnb_4bit_quant_type="nf4",
                )
            else:
                bnb_config = BnbQuantizationConfig(
                    load_in_8bit=True, torch_dtype=torch.float32
                )
            model = load_and_quantize_model(
                model,
                weights_location=self.checkpoint_path,
                bnb_quantization_config=bnb_config,
                device_map=device_map,
                offload_folder="offload",
            ).eval()
        else:
            raise NotImplementedError(f"Unsupported precision: {self.precision}")

        return model

    def _save_output_image(
        self,
        image: Image.Image,
        mode: str,
        img_path: Optional[str],
    ) -> str:
        """Write ``image`` below ``<out_img_dir>/images`` and return its path.

        Edit modes with a reference ``img_path`` reuse the reference file name
        inside a sub-directory named after the reference's parent directory
        (``"default"`` when that name is empty).  All other cases get a
        ``sample_...`` name built from a timestamp and a random suffix.

        Args:
            image: Image to save.
            mode: The mode that produced the image.
            img_path: Optional path of the single input image.

        Returns:
            The path of the written file as a string.

        Raises:
            ValueError: If ``image`` is ``None`` or ``mode`` is not an
                image-producing mode.
        """
        if image is None:
            raise ValueError(
                f"[OutputError] Mode={mode} expected an image output, but got None."
            )

        images_root = Path(self.out_img_dir) / "images"

        if mode in self._EDIT_MODES:
            if img_path:
                src = Path(img_path)
                out_dir = images_root / (src.parent.name or "default")
                filename = src.name
            else:
                out_dir = images_root / "edit"
                ts = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
                filename = f"sample_edit_{ts}_{uuid.uuid4().hex[:8]}.jpg"
        elif mode in self._GENERATE_MODES:
            out_dir = images_root
            ts = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
            filename = f"sample_{ts}_{uuid.uuid4().hex[:8]}.jpg"
        else:
            raise ValueError(f"[OutputError] Unexpected mode for image saving: {mode}")

        # parents=True also creates images_root when it is missing.
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / filename
        image.save(out_path)
        return str(out_path)

    def generate(self, question: str, images: list[str] | None = None, **kwargs):
        """Run one inference in ``self.mode``.

        ``question`` is split at ``"<image>"`` placeholders and interleaved
        with the opened images.  When images are given but the question
        contains no placeholder, one ``"<image>\\n"`` per image is prepended.

        Args:
            question: Prompt text, optionally containing ``<image>`` markers.
            images: Paths of the input images, in placeholder order.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            For the edit/generate modes: the path of the saved output image,
            or ``None`` when no image was produced.  For the other modes: the
            first text output, or ``None`` when there is none.

        Raises:
            ValueError: If the placeholder count does not match the number of
                images.
            RuntimeError: If an input image cannot be loaded.
        """
        mode = self.mode
        images = images or []

        # Auto-prepend <image> placeholders if the question doesn't contain them
        if images and question.count("<image>") == 0:
            question = "<image>\n" * len(images) + question

        text_parts = question.split("<image>")
        if len(text_parts) != len(images) + 1:
            raise ValueError(
                "Text image tokens and number of images do not match: "
                f"{len(text_parts) - 1} '<image>' placeholder(s) vs {len(images)} image(s)."
            )

        input_lists = []
        input_img_paths = []
        for i, part in enumerate(text_parts):
            text = part.strip()
            if text:
                input_lists.append(text)
            if i < len(images):
                img_path = images[i]
                try:
                    image = Image.open(img_path)
                    # Read the pixel data now: this surfaces unreadable files
                    # here and lets Pillow release the underlying file handle
                    # instead of keeping it open for the whole inference.
                    image.load()
                except Exception as e:
                    raise RuntimeError(f"Can not load image {img_path}: {e}") from e
                input_lists.append(image)
                input_img_paths.append(img_path)

        # Shallow copy so popping the flags does not modify BASE_PARAMS.
        params = dict(BASE_PARAMS[mode])
        understanding_output_flag = params.pop("understanding_output", False)
        think_flag = params.pop("think", False)

        res = self.inferencer.interleave_inference(
            input_lists=input_lists,
            think=think_flag,
            understanding_output=understanding_output_flag,
            **params,
        )

        out_images = [item for item in res if isinstance(item, Image.Image)]
        # Image and str never overlap, so filtering twice matches an if/elif.
        out_texts = [item for item in res if isinstance(item, str)]
        img_cnt, txt_cnt = len(out_images), len(out_texts)

        if img_cnt + txt_cnt != 1:
            print(
                f"[Warning] You are using {mode} mode, so the output has "
                f"{img_cnt} images and {txt_cnt} texts"
            )
            if txt_cnt > 0:
                print(f"[Warning] The text output is: {out_texts[0]}")

        first_image = out_images[0] if img_cnt else None
        first_text = out_texts[0] if txt_cnt else None

        if mode in self._EDIT_MODES or mode in self._GENERATE_MODES:
            if first_image is None:
                return None
            # Only a single input image is used as naming reference.
            ref_img_path = input_img_paths[0] if len(input_img_paths) == 1 else None
            return self._save_output_image(
                image=first_image,
                mode=mode,
                img_path=ref_img_path,
            )
        return first_text
SenseNova-SI-main/sensenova_si/bagel_utils/__init__.py
0 → 100644
View file @
876a36a4
SenseNova-SI-main/sensenova_si/bagel_utils/data/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
SenseNova-SI-main/sensenova_si/bagel_utils/data/data_utils.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
math
import
random
import
torch
from
PIL
import
Image
# from torch.nn.attention.flex_attention import or_masks, and_masks
def
create_sparse_mask
(
document_lens
,
split_lens
,
attn_modes
,
device
):
def
causal_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
q_idx
>=
kv_idx
def
full_and_noise_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
(
full_and_noise_seq_id
[
q_idx
]
==
full_and_noise_seq_id
[
kv_idx
])
&
(
full_and_noise_seq_id
[
q_idx
]
>=
0
)
def
remove_noise_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
~
(
(
noise_seq_id
[
kv_idx
]
>=
0
)
&
(
noise_seq_id
[
q_idx
]
!=
noise_seq_id
[
kv_idx
])
)
def
sample_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
document_id
[
q_idx
]
==
document_id
[
kv_idx
]
full_and_noise_tmp
=
[]
noise_tmp
=
[]
for
i
,
(
length
,
model
)
in
enumerate
(
zip
(
split_lens
,
attn_modes
)):
value
=
i
if
model
in
[
"full"
,
"noise"
]
else
-
1
full_and_noise_tmp
.
extend
([
value
]
*
length
)
value_noise
=
i
if
model
==
"noise"
else
-
1
noise_tmp
.
extend
([
value_noise
]
*
length
)
full_and_noise_seq_id
=
torch
.
Tensor
(
full_and_noise_tmp
).
to
(
device
)
noise_seq_id
=
torch
.
Tensor
(
noise_tmp
).
to
(
device
)
document_id
=
torch
.
cat
(
[
torch
.
full
((
l
,),
i
)
for
i
,
l
in
enumerate
(
document_lens
,
start
=
1
)]
).
to
(
device
)
return
and_masks
(
or_masks
(
causal_mask
,
full_and_noise_mask
),
remove_noise_mask
,
sample_mask
)
def patchify(image, patch_size):
    """Flatten a ``(c, h, w)`` tensor into rows of ``patch_size`` x ``patch_size`` patches.

    Patches are emitted in row-major order over the patch grid; inside a row
    the values are ordered by (row in patch, column in patch, channel).

    Args:
        image: Tensor whose shape unpacks to ``(c, h, w)``.
        patch_size: Side length of a patch; must divide both ``h`` and ``w``.

    Returns:
        Tensor of shape ``(num_patches, patch_size ** 2 * c)``.
    """
    channels, height, width = image.shape
    assert height % patch_size == 0 and width % patch_size == 0

    grid_h = height // patch_size
    grid_w = width // patch_size
    tiles = image.reshape(channels, grid_h, patch_size, grid_w, patch_size)
    # (c, gh, p, gw, q) -> (gh, gw, p, q, c)
    tiles = tiles.permute(1, 3, 2, 4, 0)
    return tiles.reshape(-1, patch_size**2 * channels)
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return flattened position ids of the patch grid of an image.

    The patch at grid position ``(row, col)`` gets the id
    ``row * max_num_patches_per_side + col``.

    Args:
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        patch_size: Side length of a patch in pixels.
        max_num_patches_per_side: Row stride of the id grid.

    Returns:
        1-D tensor with one id per patch, in row-major order.
    """
    rows = torch.arange(0, img_h // patch_size)
    cols = torch.arange(0, img_w // patch_size)
    row_offsets = rows[:, None] * max_num_patches_per_side
    return (row_offsets + cols).flatten()
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return flattened position ids by bucketing patches onto a fixed grid.

    Each patch row/column is mapped to its fractional coordinate in
    ``[0, 1)`` and bucketed against ``max_num_patches_per_side`` evenly
    spaced boundaries; the id is ``bucket_row * max_num_patches_per_side +
    bucket_col``.

    Args:
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        patch_size: Side length of a patch in pixels.
        max_num_patches_per_side: Number of buckets per side.

    Returns:
        1-D tensor with one id per patch, in row-major order.
    """
    step = 1 / max_num_patches_per_side
    boundaries = torch.arange(step, 1.0, step)

    def bucket_coords(num_patches):
        # Fractional start coordinate of every patch along one axis.
        fractional = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractional, boundaries, right=True)

    bucket_coords_h = bucket_coords(img_h // patch_size)
    bucket_coords_w = bucket_coords(img_w // patch_size)
    return (bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w).flatten()
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """Build the additive attention mask of one sample.

    The sample is the concatenation of several splits.  Every split may
    attend to all earlier splits; inside a split, ``"causal"`` uses a
    lower-triangular pattern while ``"full"`` and ``"noise"`` attend to the
    whole split.  Afterwards the columns of every ``"noise"`` split are hidden
    from all positions outside that split.

    Args:
        split_lens: List of ints, the length of each split in the sample.
        attn_modes: List with the attention mode of each split; each entry is
            ``"causal"``, ``"full"`` or ``"noise"``.
        device: Device on which the mask is created.

    Returns:
        Float tensor of shape ``(sum(split_lens), sum(split_lens))`` holding
        ``0`` where attention is allowed and ``-inf`` where it is blocked.
    """
    sample_len = sum(split_lens)
    allowed = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)

    start = 0
    for length, attn_mode in zip(split_lens, attn_modes):
        assert attn_mode in ["causal", "full", "noise"]
        end = start + length
        if attn_mode == "causal":
            allowed[start:end, start:end] = torch.ones(
                (length, length), device=device
            ).tril()
        else:
            # Scalar assignment avoids a temporary on the default device.
            allowed[start:end, start:end] = True
        # Every split sees everything that precedes it.
        allowed[start:end, :start] = True
        start = end

    start = 0
    for length, attn_mode in zip(split_lens, attn_modes):
        end = start + length
        if attn_mode == "noise":
            # Hide the noise columns from all rows, then re-enable the
            # split's own block.
            allowed[:, start:end] = False
            allowed[start:end, start:end] = True
        start = end

    attention_mask = torch.zeros_like(allowed, dtype=torch.float).masked_fill_(
        ~allowed, float("-inf")
    )
    return attention_mask
def split_integer_exp_decay(S, ng_sample_decay=1.0):
    """Randomly split the integer ``S`` into positive parts.

    The number of parts is drawn uniformly from ``1..S`` when
    ``ng_sample_decay`` is ``1.0``; otherwise part count ``k`` is drawn with a
    weight proportional to ``ng_sample_decay ** (k - 1)``.  The cut positions
    are then sampled without replacement from ``1..S-1``.

    Args:
        S: Integer to split.
        ng_sample_decay: Decay factor of the part-count distribution.

    Returns:
        Tuple ``(result, cumsum)`` where ``result`` lists the part sizes and
        ``cumsum`` the boundaries, starting with ``0`` and ending with ``S``.
    """
    if ng_sample_decay == 1.0:
        num_parts = random.randint(1, S)
    else:
        norm = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
        weights = [norm * math.pow(ng_sample_decay, k) for k in range(S)]
        candidates = list(range(1, S + 1))
        num_parts = random.choices(candidates, weights, k=1)[0]

    cut_points = sorted(random.sample(range(1, S), num_parts - 1))
    cumsum = [0] + cut_points + [S]
    result = [upper - lower for lower, upper in zip(cumsum, cumsum[1:])]
    return result, cumsum
def pil_img2rgb(image):
    """Convert a PIL image to RGB, flattening transparency onto white.

    Images in ``"RGBA"`` mode, or carrying a ``"transparency"`` entry in
    ``image.info``, are converted to RGBA and pasted onto a white RGB canvas
    using their alpha band as mask.  Everything else is converted with
    ``convert("RGB")``.

    Args:
        image: The PIL image to convert.

    Returns:
        The resulting RGB image.
    """
    has_alpha = image.mode == "RGBA" or image.info.get("transparency", None) is not None
    if not has_alpha:
        return image.convert("RGB")

    rgba = image.convert("RGBA")
    canvas = Image.new(mode="RGB", size=rgba.size, color=(255, 255, 255))
    # Band 3 of an RGBA image is the alpha channel.
    canvas.paste(rgba, mask=rgba.split()[3])
    return canvas
def add_special_tokens(tokenizer):
    """Make sure the chat/vision marker tokens exist in ``tokenizer``.

    Any of ``<|im_start|>``, ``<|im_end|>``, ``<|vision_start|>`` and
    ``<|vision_end|>`` that is not already listed in
    ``tokenizer.special_tokens_map`` is added via ``tokenizer.add_tokens``.

    Args:
        tokenizer: Tokenizer providing ``special_tokens_map``, ``add_tokens``
            and ``convert_tokens_to_ids``.

    Returns:
        Tuple ``(tokenizer, new_token_ids, num_new_tokens)`` where
        ``new_token_ids`` maps ``bos_token_id``, ``eos_token_id``,
        ``start_of_image`` and ``end_of_image`` to their token ids and
        ``num_new_tokens`` is the value returned by ``add_tokens``.
    """
    known_special = []
    for entry in tokenizer.special_tokens_map.values():
        if isinstance(entry, str):
            known_special.append(entry)
        elif isinstance(entry, list):
            known_special += entry

    required = ("<|im_start|>", "<|im_end|>", "<|vision_start|>", "<|vision_end|>")
    missing = [token for token in required if token not in known_special]
    num_new_tokens = tokenizer.add_tokens(missing)

    to_id = tokenizer.convert_tokens_to_ids
    new_token_ids = {
        "bos_token_id": to_id("<|im_start|>"),
        "eos_token_id": to_id("<|im_end|>"),
        "start_of_image": to_id("<|vision_start|>"),
        "end_of_image": to_id("<|vision_end|>"),
    }
    return tokenizer, new_token_ids, num_new_tokens
def
len2weight
(
x
,
loss_reduction
=
"square"
):
if
x
==
0
:
return
x
if
loss_reduction
==
"token"
:
return
1
if
loss_reduction
==
"sample"
:
return
1
/
x
if
loss_reduction
==
"square"
:
return
1
/
(
x
**
0.5
)
raise
NotImplementedError
(
loss_reduction
)
SenseNova-SI-main/sensenova_si/bagel_utils/data/transforms.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
random
import
cv2
import
numpy
as
np
import
torch
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision.transforms
import
InterpolationMode
from
torchvision.transforms
import
functional
as
F
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
    """Resize an image so its edges respect size limits and a stride.

    The image is scaled down so the longest edge does not exceed ``max_size``
    and scaled up so the shortest edge reaches ``min_size``; both edges are
    then rounded to multiples of ``stride``.  A pixel budget and the maximum
    edge length are enforced afterwards.

    Args:
        max_size (int): Maximum size for the longest edge of the image.
        min_size (int): Minimum size for the shortest edge of the image.
        stride (int): Value by which the height and width must be divisible.
        max_pixels (int): Maximum number of pixels for the full image.
        interpolation (InterpolationMode): Interpolation passed to
            ``torchvision.transforms.functional.resize``. Default is
            ``InterpolationMode.BICUBIC``.
        antialias (bool, optional): Whether to apply antialiasing
            (default is True).
    """

    def __init__(
        self,
        max_size: int,
        min_size: int,
        stride: int,
        max_pixels: int,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ):
        super().__init__()
        self.max_size = max_size
        self.min_size = min_size
        self.stride = stride
        self.max_pixels = max_pixels
        self.interpolation = interpolation
        self.antialias = antialias

    def _make_divisible(self, value, stride):
        """Round ``value`` to the nearest multiple of ``stride`` (at least ``stride``)."""
        rounded = int(round(value / stride) * stride)
        return max(stride, rounded)

    def _apply_scale(self, width, height, scale):
        """Scale both edges by ``scale`` and snap them to ``self.stride``."""
        scaled_w = self._make_divisible(round(width * scale), self.stride)
        scaled_h = self._make_divisible(round(height * scale), self.stride)
        return scaled_w, scaled_h

    def forward(self, img, img_num=1):
        """Resize ``img`` according to the configured limits.

        Args:
            img (PIL Image or Tensor): Image to be resized. For tensors the
                last two dimensions are taken as ``(height, width)``.
            img_num (int): Number of images; the pixel budget per image is
                ``max_pixels / img_num``.

        Returns:
            PIL Image or Tensor: Rescaled image with divisible dimensions.
        """
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        # Shrink to fit the longest edge, but never below the short-edge minimum.
        shrink = min(self.max_size / max(width, height), 1.0)
        factor = max(shrink, self.min_size / min(width, height))
        out_w, out_h = self._apply_scale(width, height, factor)

        # Ensure the number of pixels does not exceed the per-image budget.
        pixel_budget = self.max_pixels / img_num
        if out_w * out_h > pixel_budget:
            factor = self.max_pixels / img_num / (out_w * out_h)
            out_w, out_h = self._apply_scale(out_w, out_h, factor)

        # Ensure the longest edge does not exceed max_size.
        longest = max(out_w, out_h)
        if longest > self.max_size:
            factor = self.max_size / longest
            out_w, out_h = self._apply_scale(out_w, out_h, factor)

        return F.resize(img, (out_h, out_w), self.interpolation, antialias=self.antialias)
class ImageTransform:
    """Resize, tensor-convert and normalize an image.

    Combines ``MaxLongEdgeMinShortEdgeResize`` with torchvision's
    ``ToTensor`` and an in-place ``Normalize``.
    """

    def __init__(
        self,
        max_image_size,
        min_image_size,
        image_stride,
        max_pixels=14 * 14 * 9 * 1024,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    ):
        """Create the three transform stages.

        Args:
            max_image_size: Maximum length of the longest image edge.
            min_image_size: Minimum length of the shortest image edge.
            image_stride: Value both edges must be divisible by; also stored
                as ``self.stride``.
            max_pixels: Pixel budget handed to the resize stage.
            image_mean: Per-channel mean used for normalization.
            image_std: Per-channel standard deviation used for normalization.
        """
        self.stride = image_stride

        resize_kwargs = {
            "max_size": max_image_size,
            "min_size": min_image_size,
            "stride": image_stride,
            "max_pixels": max_pixels,
        }
        self.resize_transform = MaxLongEdgeMinShortEdgeResize(**resize_kwargs)
        self.to_tensor_transform = transforms.ToTensor()
        self.normalize_transform = transforms.Normalize(
            mean=image_mean, std=image_std, inplace=True
        )

    def __call__(self, img, img_num=1):
        """Apply resize, ``ToTensor`` and normalization to ``img``.

        Args:
            img: Image to transform.
            img_num: Number of images, forwarded to the resize stage.

        Returns:
            The normalized image tensor.
        """
        resized = self.resize_transform(img, img_num=img_num)
        as_tensor = self.to_tensor_transform(resized)
        return self.normalize_transform(as_tensor)
def decolorization(image):
    """Return a grayscale version of ``image``.

    RGB inputs stay three-channel: the grayscale band is replicated into an
    RGB image.  Every other mode (including ``"L"``) yields the single-band
    ``"L"`` conversion.

    Previously ``"L"`` inputs were passed to ``Image.merge("L", [gray] * 3)``,
    which supplies three bands for a one-band mode and raises.

    Args:
        image: The PIL image to convert.

    Returns:
        The grayscale PIL image.
    """
    gray_image = image.convert("L")
    if image.mode == "RGB":
        return Image.merge("RGB", [gray_image] * 3)
    return gray_image
def downscale(image, scale_factor):
    """Resize ``image`` by ``scale_factor`` using bicubic resampling.

    Both target edges are rounded to the nearest integer and clamped to at
    least one pixel.

    Args:
        image: The PIL image to resize.
        scale_factor: Multiplier applied to width and height.

    Returns:
        The resized PIL image.
    """
    target_size = tuple(
        max(1, int(round(edge * scale_factor))) for edge in (image.width, image.height)
    )
    return image.resize(target_size, resample=Image.BICUBIC)
def crop(image, crop_factors):
    """Cut a randomly placed window out of ``image``.

    Args:
        image: The PIL image to crop.
        crop_factors: ``(target_h, target_w)`` size of the window in pixels.

    Returns:
        Tuple ``(cropped_image, [[x0, y0], [x1, y1]])`` with the top-left and
        bottom-right corners of the window.

    Raises:
        ValueError: If the window is larger than the image.
    """
    target_h, target_w = crop_factors
    img_w, img_h = image.size
    if target_h > img_h or target_w > img_w:
        raise ValueError("Crop size exceeds image dimensions")

    left = random.randint(0, img_w - target_w)
    top = random.randint(0, img_h - target_h)
    right = left + target_w
    bottom = top + target_h

    window = image.crop((left, top, right, bottom))
    return window, [[left, top], [right, bottom]]
def motion_blur_opencv(image, kernel_size=15, angle=0):
    """Apply a linear motion blur to ``image`` using OpenCV.

    A horizontal line kernel is rotated by ``angle``, normalized to sum to
    one (when its sum is non-zero) and convolved with the image using
    reflected borders.

    Args:
        image: Input image; converted with ``np.array``.
        kernel_size: Side length of the square blur kernel.
        angle: Rotation passed to ``cv2.getRotationMatrix2D``.

    Returns:
        The blurred image as a PIL image with ``uint8`` data.
    """
    # Linear kernel: a single row of ones through the middle.
    line_kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    line_kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)

    # Rotate the kernel around its centre.
    pivot = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
    rotation = cv2.getRotationMatrix2D(pivot, angle, 1)
    rotated_kernel = cv2.warpAffine(line_kernel, rotation, (kernel_size, kernel_size))

    # Normalize the kernel; an all-zero kernel is left as is.
    kernel_total = rotated_kernel.sum()
    rotated_kernel /= kernel_total if kernel_total != 0 else 1

    pixels = np.array(image)
    if pixels.ndim == 2:
        blurred = cv2.filter2D(pixels, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
    else:
        # For colour images every channel is convolved independently.
        blurred = np.zeros_like(pixels)
        for channel in range(pixels.shape[2]):
            blurred[..., channel] = cv2.filter2D(
                pixels[..., channel], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT
            )

    return Image.fromarray(blurred.astype(np.uint8))
def shuffle_patch(image, num_splits, gap_size=2):
    """Split the image into patches, shuffle them and reassemble with gaps.

    The image size does not have to be divisible by the split counts: the
    last row/column absorbs the remainder.  The shuffled patches are pasted
    back row by row, separated by ``gap_size`` pixels, onto a canvas created
    with colour ``(255, 255, 255)``.

    Args:
        image: The PIL image to process.
        num_splits: ``(h_splits, w_splits)``, the number of rows and columns.
        gap_size: Gap in pixels between neighbouring patches.

    Returns:
        The reassembled PIL image.
    """
    rows, cols = num_splits
    width, height = image.size

    # Every row/column uses the floor-divided size except the last one.
    row_heights = [height // rows] * (rows - 1)
    row_heights.append(height - sum(row_heights))
    col_widths = [width // cols] * (cols - 1)
    col_widths.append(width - sum(col_widths))

    # Cut the patches in row-major order.
    tiles = []
    top = 0
    for tile_h in row_heights:
        left = 0
        for tile_w in col_widths:
            tiles.append(image.crop((left, top, left + tile_w, top + tile_h)))
            left += tile_w
        top += tile_h

    random.shuffle(tiles)

    canvas_size = (
        sum(col_widths) + (cols - 1) * gap_size,
        sum(row_heights) + (rows - 1) * gap_size,
    )
    canvas = Image.new(image.mode, canvas_size, color=(255, 255, 255))

    # Paste the shuffled patches; each slot starts where the previous slot
    # ended plus the gap.
    shuffled = iter(tiles)
    top = 0
    for tile_h in row_heights:
        left = 0
        for tile_w in col_widths:
            canvas.paste(next(shuffled), (left, top))
            left += tile_w + gap_size
        top += tile_h + gap_size

    return canvas
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
    """Blank out a random subset of patches, e.g. for an inpainting task.

    The image is split into a grid of patches (the last row/column absorbs
    any remainder), a random selection of them is replaced by solid
    ``blank_color`` patches and everything is pasted back at its original
    position.

    Args:
        image: Input PIL image.
        num_splits: ``(h_splits, w_splits)``, the number of rows (vertical
            splits) and columns (horizontal splits).
        blank_ratio: Fraction of patches to blank; the resulting count is
            truncated to an int and clamped to ``[0, total patches]``.
        blank_color: Colour of the blanked patches.

    Returns:
        A new RGB PIL image of the same size as the input.
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    patch_heights = [img_h // h_splits] * (h_splits - 1)
    patch_heights.append(img_h - sum(patch_heights))
    patch_widths = [img_w // w_splits] * (w_splits - 1)
    patch_widths.append(img_w - sum(patch_widths))

    total_patches = h_splits * w_splits
    num_blank = max(0, min(int(total_patches * blank_ratio), total_patches))
    # A set gives O(1) membership checks in the loop below.
    blank_indices = set(random.sample(range(total_patches), num_blank))

    # The result has the same size as the original image.
    result_image = Image.new("RGB", (img_w, img_h))

    patch_idx = 0
    current_y = 0
    for patch_h in patch_heights:
        current_x = 0
        for patch_w in patch_widths:
            box = (current_x, current_y, current_x + patch_w, current_y + patch_h)
            patch = image.crop(box)
            if patch_idx in blank_indices:
                patch = Image.new("RGB", patch.size, color=blank_color)
            # Paste back at the original position.
            result_image.paste(patch, (current_x, current_y))
            current_x += patch_w
            patch_idx += 1
        current_y += patch_h

    return result_image
SenseNova-SI-main/sensenova_si/bagel_utils/inferencer.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
copy
import
deepcopy
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
from
PIL
import
Image
from
.data.data_utils
import
pil_img2rgb
from
.modeling.bagel.qwen2_navit
import
NaiveCache
VLM_THINK_SYSTEM_PROMPT
=
"""You should first think about the reasoning process in the mind and then provide the user with the answer.
The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here"""
GEN_THINK_SYSTEM_PROMPT
=
"""You should first think about the planning process in the mind and then generate the image.
The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here"""
class InterleaveInferencer:
    """Run interleaved text / image inference on top of a BAGEL-style model.

    State is carried in a "generation context" dict with three entries:
    ``"kv_lens"`` and ``"ropes"`` (lists with one entry per sample; only a
    single sample is supported) and ``"past_key_values"`` (a ``NaiveCache``).
    """

    def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
        """Store the model, VAE, tokenizer, image transforms and special-token ids."""
        self.model = model
        self.vae_model = vae_model
        self.tokenizer = tokenizer
        self.vae_transform = vae_transform
        self.vit_transform = vit_transform
        self.new_token_ids = new_token_ids

    def init_gen_context(self):
        """Return an empty generation context (zero cached length, fresh cache)."""
        gen_context = {
            "kv_lens": [0],
            "ropes": [0],
            "past_key_values": NaiveCache(self.model.config.llm_config.num_hidden_layers),
        }
        return gen_context

    @torch.no_grad()
    def update_context_text(self, text, gen_context):
        """Append ``text`` to the context and return the (in-place updated) context.

        Used for interleaved data; currently only supports single-sample inference.
        """
        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]

        generation_input, kv_lens, ropes = self.model.prepare_prompts(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            prompts=[text],
            tokenizer=self.tokenizer,
            new_token_ids=self.new_token_ids,
        )
        past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)

        gen_context["kv_lens"] = kv_lens
        gen_context["ropes"] = ropes
        gen_context["past_key_values"] = past_key_values
        return gen_context

    @torch.no_grad()
    def update_context_image(self, image, gen_context, vae=True, vit=True):
        """Append ``image`` to the context through the VAE and/or ViT branch.

        Used for interleaved data; currently only supports single-sample
        inference. At least one of ``vae`` / ``vit`` must be true. When both
        are set, the VAE tokens are written to the cache before the ViT tokens.
        Returns the (in-place updated) context.
        """
        assert vae or vit
        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]

        if vae:
            # update vae
            generation_input, kv_lens, ropes = self.model.prepare_vae_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vae_transform,
                new_token_ids=self.new_token_ids,
            )
            past_key_values = self.model.forward_cache_update_vae(
                self.vae_model, past_key_values, **generation_input
            )

        if vit:
            # update vit
            generation_input, kv_lens, ropes = self.model.prepare_vit_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vit_transform,
                new_token_ids=self.new_token_ids,
            )
            past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)

        gen_context["kv_lens"] = kv_lens
        gen_context["ropes"] = ropes
        gen_context["past_key_values"] = past_key_values
        return gen_context

    @torch.no_grad()
    def gen_image(
        self,
        image_shape,
        gen_context,
        cfg_text_scale=4.0,
        cfg_img_scale=1.5,
        cfg_text_precontext=None,
        cfg_img_precontext=None,
        cfg_interval=(0.4, 1.0),
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        num_timesteps=50,
        timestep_shift=3.0,
        enable_taylorseer=False,
    ):
        """Generate one image of size ``image_shape`` (``(H, W)``) from the context.

        NOTE: ``cfg_text_precontext`` and ``cfg_img_precontext`` default to
        ``None`` but are indexed unconditionally below, so callers must pass
        both (each a generation-context dict).

        Returns:
            The decoded PIL image.
        """
        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]
        generation_input = self.model.prepare_vae_latent(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            image_sizes=[image_shape],
            new_token_ids=self.new_token_ids,
        )

        # text cfg
        cfg_text_past_key_values = cfg_text_precontext["past_key_values"]
        kv_lens_cfg = cfg_text_precontext["kv_lens"]
        ropes_cfg = cfg_text_precontext["ropes"]
        generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg,
            image_sizes=[image_shape],
        )

        # img cfg
        cfg_img_past_key_values = cfg_img_precontext["past_key_values"]
        kv_lens_cfg = cfg_img_precontext["kv_lens"]
        ropes_cfg = cfg_img_precontext["ropes"]
        generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg,
            image_sizes=[image_shape],
        )

        unpacked_latent = self.model.generate_image(
            past_key_values=past_key_values,
            cfg_text_past_key_values=cfg_text_past_key_values,
            cfg_img_past_key_values=cfg_img_past_key_values,
            num_timesteps=num_timesteps,
            cfg_text_scale=cfg_text_scale,
            cfg_img_scale=cfg_img_scale,
            cfg_interval=cfg_interval,
            cfg_renorm_min=cfg_renorm_min,
            cfg_renorm_type=cfg_renorm_type,
            timestep_shift=timestep_shift,
            **generation_input,
            cfg_text_packed_position_ids=generation_input_cfg_text["cfg_packed_position_ids"],
            cfg_text_packed_query_indexes=generation_input_cfg_text["cfg_packed_query_indexes"],
            cfg_text_key_values_lens=generation_input_cfg_text["cfg_key_values_lens"],
            cfg_text_packed_key_value_indexes=generation_input_cfg_text[
                "cfg_packed_key_value_indexes"
            ],
            cfg_img_packed_position_ids=generation_input_cfg_img["cfg_packed_position_ids"],
            cfg_img_packed_query_indexes=generation_input_cfg_img["cfg_packed_query_indexes"],
            cfg_img_key_values_lens=generation_input_cfg_img["cfg_key_values_lens"],
            cfg_img_packed_key_value_indexes=generation_input_cfg_img[
                "cfg_packed_key_value_indexes"
            ],
            enable_taylorseer=enable_taylorseer,
        )

        image = self.decode_image(unpacked_latent[0], image_shape)
        return image

    def decode_image(self, latent, image_shape):
        """Un-patchify ``latent``, decode it with the VAE and return a PIL image.

        ``image_shape`` is ``(H, W)``; the latent grid is ``H`` and ``W``
        divided by ``model.latent_downsample``.
        """
        H, W = image_shape
        h, w = H // self.model.latent_downsample, W // self.model.latent_downsample

        # (tokens, p*p*c) -> (1, h, w, p, p, c) -> (1, c, h*p, w*p)
        latent = latent.reshape(
            1,
            h,
            w,
            self.model.latent_patch_size,
            self.model.latent_patch_size,
            self.model.latent_channel,
        )
        latent = torch.einsum("nhwpqc->nchpwq", latent)
        latent = latent.reshape(
            1,
            self.model.latent_channel,
            h * self.model.latent_patch_size,
            w * self.model.latent_patch_size,
        )
        image = self.vae_model.decode(latent)
        # Map [-1, 1] to [0, 255] and move channels last for PIL.
        image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
        image = Image.fromarray((image).to(torch.uint8).cpu().numpy())

        return image

    @torch.no_grad()
    def gen_text(
        self,
        gen_context,
        max_length: int = 500,
        do_sample: bool = True,
        temperature: float = 1.0,
    ):
        """Generate text from a deep copy of ``gen_context`` (the caller's context is untouched).

        The decoded output is cut at the first ``<|im_end|>`` and everything up
        to and including the first ``<|im_start|>`` is dropped; an
        ``IndexError`` is raised if ``<|im_start|>`` is absent from the decoded text.
        """
        gen_context = deepcopy(gen_context)
        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]

        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
        unpacked_latent = self.model.generate_text(
            past_key_values=past_key_values,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            end_token_id=self.new_token_ids["eos_token_id"],
            **generation_input,
        )
        output = self.tokenizer.decode(unpacked_latent[:, 0])
        output = output.split("<|im_end|>")[0].split("<|im_start|>")[1]
        return output

    @torch.no_grad()
    def interleave_inference(
        self,
        input_lists: List[Union[str, Image.Image]],
        think=False,
        understanding_output=False,
        max_think_token_n=1000,
        do_sample=False,
        text_temperature=0.3,
        cfg_text_scale=3.0,
        cfg_img_scale=1.5,
        # Tuple rather than list: a mutable default would be shared between calls.
        cfg_interval=(0.4, 1.0),
        timestep_shift=3.0,
        num_timesteps=50,
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        image_shapes=(1024, 1024),
        enable_taylorseer=False,
    ) -> List[Union[str, Image.Image]]:
        """Feed a list of strings / PIL images to the model and generate an answer.

        Three contexts are maintained: the full context, a text-CFG context
        (a snapshot taken before each text input, and after each image input)
        and an image-CFG context (which only receives text).

        Args:
            input_lists: Inputs in order; each item is a ``str`` or a PIL image.
            think: Prepend a "think" system prompt; in generation mode also
                generate a text plan before the image.
            understanding_output: If true, output text only; otherwise output
                an image (preceded by the plan text when ``think`` is set).
            max_think_token_n: ``max_length`` passed to text generation.
            image_shapes: ``(H, W)`` of the generated image; replaced by the
                size of the last input image when one is given.

        Returns:
            List of generated items (strings and/or a PIL image).

        Raises:
            ValueError: If an input item is neither ``str`` nor a PIL image.
        """
        output_list = []
        gen_context = self.init_gen_context()
        cfg_text_context = deepcopy(gen_context)
        cfg_img_context = deepcopy(gen_context)

        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            if think:
                if understanding_output:
                    system_prompt = VLM_THINK_SYSTEM_PROMPT
                else:
                    system_prompt = GEN_THINK_SYSTEM_PROMPT
                gen_context = self.update_context_text(system_prompt, gen_context)
                cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)

            for input_term in input_lists:
                if isinstance(input_term, str):
                    # Snapshot before adding the text: the text-CFG branch omits it.
                    cfg_text_context = deepcopy(gen_context)
                    gen_context = self.update_context_text(input_term, gen_context)
                    cfg_img_context = self.update_context_text(input_term, cfg_img_context)

                elif isinstance(input_term, Image.Image):
                    input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
                    gen_context = self.update_context_image(
                        input_term, gen_context, vae=not understanding_output
                    )
                    # PIL size is (W, H); the model expects (H, W).
                    image_shapes = input_term.size[::-1]
                    cfg_text_context = deepcopy(gen_context)

                else:
                    raise ValueError(f"Unsupported input type: {type(input_term)}")

            if understanding_output:
                gen_text = self.gen_text(
                    gen_context,
                    do_sample=do_sample,
                    temperature=text_temperature,
                    max_length=max_think_token_n,
                )
                output_list.append(gen_text)

            else:
                if think:
                    gen_text = self.gen_text(
                        gen_context,
                        do_sample=do_sample,
                        temperature=text_temperature,
                        max_length=max_think_token_n,
                    )
                    gen_context = self.update_context_text(gen_text, gen_context)
                    output_list.append(gen_text)

                img = self.gen_image(
                    image_shapes,
                    gen_context,
                    cfg_text_precontext=cfg_text_context,
                    cfg_img_precontext=cfg_img_context,
                    cfg_text_scale=cfg_text_scale,
                    cfg_img_scale=cfg_img_scale,
                    cfg_interval=cfg_interval,
                    timestep_shift=timestep_shift,
                    num_timesteps=num_timesteps,
                    cfg_renorm_min=cfg_renorm_min,
                    cfg_renorm_type=cfg_renorm_type,
                    enable_taylorseer=enable_taylorseer,
                )
                output_list.append(img)

        return output_list

    def __call__(
        self, image: Optional[Image.Image] = None, text: Optional[str] = None, **kargs
    ) -> Dict[str, Any]:
        """Convenience wrapper around ``interleave_inference`` for one image and/or one text.

        The image (if any) is fed before the text. Extra keyword arguments are
        forwarded to ``interleave_inference``.

        Returns:
            ``{"image": ..., "text": ...}`` holding the last generated image
            and the last generated string (``None`` when absent). If neither
            input is given, a message is printed and both entries are ``None``.
        """
        output_dict = {"image": None, "text": None}

        if image is None and text is None:
            print("Please provide at least one input: either an image or text.")
            return output_dict

        input_list = []
        if image is not None:
            input_list.append(image)
        if text is not None:
            input_list.append(text)

        output_list = self.interleave_inference(input_list, **kargs)

        for i in output_list:
            if isinstance(i, Image.Image):
                output_dict["image"] = i
            elif isinstance(i, str):
                output_dict["text"] = i
        return output_dict
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
.
import
autoencoder
,
bagel
,
qwen2
,
siglip
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/autoencoder.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2024 Black Forest Labs.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
#
# This modified file is released under the same license.
from
dataclasses
import
dataclass
import
torch
from
einops
import
rearrange
from
safetensors.torch
import
load_file
as
load_sft
from
torch
import
Tensor
,
nn
@dataclass
class AutoEncoderParams:
    # Hyper-parameters used to build the AutoEncoder (Encoder + Decoder) in this file.
    # Nominal input resolution; the Decoder derives its `z_shape` from it.
    resolution: int
    # Number of channels of the image fed to the Encoder.
    in_channels: int
    # Not read by the modules in this file.
    # NOTE(review): presumably the overall spatial downsampling factor of the VAE -- confirm.
    downsample: int
    # Base channel width; per-level widths are `ch * ch_mult[level]`.
    ch: int
    # Number of channels produced by the Decoder.
    out_ch: int
    # Channel multiplier per resolution level; its length is the number of levels.
    ch_mult: list[int]
    # ResnetBlocks per level in the Encoder (the Decoder uses one more per level).
    num_res_blocks: int
    # Latent channels; the Encoder outputs 2 * z_channels (split into mean / logvar).
    z_channels: int
    # Latents are encoded as scale_factor * (z - shift_factor) and decoded with the inverse.
    scale_factor: float
    shift_factor: float
def swish(x: Tensor) -> Tensor:
    """Element-wise swish activation: the input multiplied by its own sigmoid."""
    gate = torch.sigmoid(x)
    return x * gate
class AttnBlock(nn.Module):
    """Residual self-attention block over the spatial positions of a feature map."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1 convolutions acting as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalize ``h_`` and attend over its ``h * w`` positions with one head."""
        normed = self.norm(h_)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        b, c, h, w = query.shape
        to_tokens = "b c h w -> b 1 (h w) c"
        query = rearrange(query, to_tokens).contiguous()
        key = rearrange(key, to_tokens).contiguous()
        value = rearrange(value, to_tokens).contiguous()

        attended = nn.functional.scaled_dot_product_attention(query, key, value)
        return rearrange(attended, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the output-projected attention result."""
        attended = self.attention(x)
        return x + self.proj_out(attended)
class ResnetBlock(nn.Module):
    """Pre-activation residual block: two (GroupNorm -> swish -> 3x3 conv) stages.

    When the channel count changes, the skip path goes through a 1x1
    convolution (``nin_shortcut``); otherwise the input is added unchanged.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # ``None`` means "keep the input width".
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, x):
        """Return the skip path (projected if needed) plus the residual branch."""
        residual = self.conv1(swish(self.norm1(x)))
        residual = self.conv2(swish(self.norm2(residual)))

        if self.in_channels == self.out_channels:
            skip = x
        else:
            skip = self.nin_shortcut(x)
        return skip + residual
class Downsample(nn.Module):
    """Stride-2 3x3 convolution with manual one-sided zero padding."""

    def __init__(self, in_channels: int):
        super().__init__()
        # torch convolutions only pad symmetrically, so the conv itself is
        # unpadded and forward() pads the right / bottom edges by hand.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        """Zero-pad one pixel on the right and bottom, then apply the strided conv."""
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling followed by a 3x3 convolution."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        """Double the spatial size of ``x`` and apply the convolution."""
        enlarged = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
class Encoder(nn.Module):
    """Convolutional encoder: stacked ResnetBlocks with downsampling between levels.

    The output has ``2 * z_channels`` channels.
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
    ):
        """Build the encoder.

        Args:
            resolution: Nominal input resolution (stored, not otherwise used here).
            in_channels: Channels of the input image.
            ch: Base channel width.
            ch_mult: Channel multiplier per level; its length is the number of levels.
            num_res_blocks: ResnetBlocks per level.
            z_channels: Latent channels; the final conv emits twice as many.
        """
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # downsampling
        self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        # Input multiplier of level i is the output multiplier of level i - 1.
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()  # left empty: no attention in the down path
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            down = nn.Module()
            down.block = block
            down.attn = attn
            # Every level except the last halves the spatial size.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return the ``2 * z_channels``-channel output of ``conv_out``."""
        # downsampling -- only the latest activation is needed, so intermediate
        # feature maps are not retained in a list.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.downsample(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
class Decoder(nn.Module):
    """Convolutional decoder: mid blocks, then ResnetBlocks with upsampling between levels."""

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # Channel width and spatial size at the lowest-resolution level.
        top_level = self.num_resolutions - 1
        block_in = ch * ch_mult[top_level]
        lowest_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, lowest_res, lowest_res)

        # z to block_in
        self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)

        # upsampling: levels are created from the coarsest to the finest
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            res_blocks = nn.ModuleList()
            attn_blocks = nn.ModuleList()  # left empty: no attention in the up path
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                res_blocks.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
                block_in = block_out
            level = nn.Module()
            level.block = res_blocks
            level.attn = attn_blocks
            if i_level != 0:
                level.upsample = Upsample(block_in)
            # Prepend so that self.up[i] is level i despite the reversed build order.
            self.up.insert(0, level)

        # end
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode latent ``z`` into an ``out_ch``-channel tensor."""
        # z to block_in, then the middle blocks
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            level = self.up[i_level]
            for i_block in range(self.num_res_blocks + 1):
                h = level.block[i_block](h)
                if len(level.attn) > 0:
                    h = level.attn[i_block](h)
            if i_level != 0:
                h = level.upsample(h)

        # end
        return self.conv_out(swish(self.norm_out(h)))
class DiagonalGaussian(nn.Module):
    """Split a tensor into mean / log-variance halves and sample (or return the mean)."""

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        """Return ``mean + std * noise`` when sampling is enabled, else ``mean``."""
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if not self.sample:
            return mean
        std = torch.exp(0.5 * logvar)
        return mean + std * torch.randn_like(mean)
class AutoEncoder(nn.Module):
    """Encoder / Decoder pair with a diagonal-Gaussian bottleneck and latent scaling."""

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        # Arguments shared by the encoder and the decoder.
        shared = dict(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(out_ch=params.out_ch, **shared)
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x``, sample the latent, then shift and scale it."""
        latent = self.reg(self.encoder(x))
        return self.scale_factor * (latent - self.shift_factor)

    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent scaling / shift and run the decoder."""
        latent = z / self.scale_factor + self.shift_factor
        return self.decoder(latent)

    def forward(self, x: Tensor) -> Tensor:
        """Round-trip ``x`` through ``encode`` and ``decode``."""
        latent = self.encode(x)
        return self.decode(latent)
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print the missing and/or unexpected state-dict keys, if there are any.

    When both lists are non-empty, the two reports are separated by a dashed
    rule. Nothing is printed when both lists are empty.
    """
    has_missing = len(missing) > 0
    has_unexpected = len(unexpected) > 0

    if has_missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if has_missing and has_unexpected:
        print("\n" + "-" * 79 + "\n")
    if has_unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_ae(local_path: str) -> tuple[AutoEncoder, AutoEncoderParams]:
    """Build the autoencoder with fixed hyper-parameters and optionally load weights.

    Args:
        local_path: Path to a safetensors checkpoint. If ``None``, the model is
            returned with its freshly initialised weights.

    Returns:
        ``(ae, ae_params)`` -- the model and the parameter object it was built from.
    """
    ae_params = AutoEncoderParams(
        resolution=256,
        in_channels=3,
        downsample=8,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )

    # Loading the autoencoder
    ae = AutoEncoder(ae_params)

    if local_path is not None:
        sd = load_sft(local_path)
        # Non-strict load: key mismatches are reported instead of raising.
        missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
        print_load_warning(missing, unexpected)
    return ae, ae_params
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
.bagel
import
Bagel
,
BagelConfig
from
.qwen2_navit
import
Qwen2Config
,
Qwen2ForCausalLM
,
Qwen2Model
from
.siglip_navit
import
SiglipVisionConfig
,
SiglipVisionModel
__all__
=
[
"BagelConfig"
,
"Bagel"
,
"Qwen2Config"
,
"Qwen2Model"
,
"Qwen2ForCausalLM"
,
"SiglipVisionConfig"
,
"SiglipVisionModel"
,
]
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/bagel.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
copy
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
tqdm
import
tqdm
# from torch.nn.attention.flex_attention import create_block_mask
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.modeling_utils
import
PreTrainedModel
from
...data.data_utils
import
(
create_sparse_mask
,
get_flattened_position_ids_extrapolate
,
get_flattened_position_ids_interpolate
,
patchify
,
)
from
..cache_utils.taylorseer
import
cache_init
from
.modeling_utils
import
MLPconnector
,
PositionEmbedding
,
TimestepEmbedder
from
.qwen2_navit
import
NaiveCache
class BagelConfig(PretrainedConfig):
    """Configuration for the ``Bagel`` model.

    Args:
        visual_gen: Enable the image-generation branch (latent / timestep layers).
        visual_und: Enable the image-understanding branch (ViT + connector).
        llm_config: Language-model config; ``Bagel`` reads ``hidden_size``,
            ``layer_module`` and ``num_attention_heads`` from it.
        vit_config: ViT config; ``Bagel`` reads ``patch_size`` and ``hidden_size``.
        vae_config: VAE config; ``Bagel`` reads ``downsample`` and ``z_channels``.
        latent_patch_size: Side length of a latent patch.
        max_latent_size: Size passed to the latent position embedding.
        vit_max_num_patch_per_side: Size passed to the ViT position embedding.
        connector_act: Activation name passed to the ViT-to-LLM connector.
        interpolate_pos: Select the interpolating (instead of extrapolating)
            flattened-position-id function.
        timestep_shift: Shift applied to the sampled timesteps in ``Bagel.forward``.
        **kwargs: Forwarded to ``PretrainedConfig``.
    """

    def __init__(
        self,
        visual_gen=True,
        visual_und=True,
        llm_config=None,
        vit_config=None,
        vae_config=None,
        latent_patch_size=2,
        max_latent_size=32,
        vit_max_num_patch_per_side=70,
        connector_act="gelu_pytorch_tanh",
        interpolate_pos=False,
        timestep_shift=1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.visual_gen = visual_gen
        self.visual_und = visual_und
        self.llm_config = llm_config
        self.vit_config = vit_config
        self.vae_config = vae_config
        self.latent_patch_size = latent_patch_size
        self.max_latent_size = max_latent_size
        self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
        self.connector_act = connector_act
        self.interpolate_pos = interpolate_pos
        self.timestep_shift = timestep_shift
class
Bagel
(
PreTrainedModel
):
config_class
=
BagelConfig
base_model_prefix
=
"bagel"
def __init__(self, language_model, vit_model, config: BagelConfig):
    """Wire the language model, the optional ViT and the generation / understanding heads.

    Generation layers are only created when ``config.visual_gen`` is set and
    understanding layers only when ``config.visual_und`` is set.
    """
    super().__init__(config)
    llm_cfg = config.llm_config

    self.language_model = language_model
    self.hidden_size = llm_cfg.hidden_size
    # Any layer-module name containing "Mo" switches on the MoE code paths.
    self.use_moe = "Mo" in llm_cfg.layer_module
    self.num_heads = llm_cfg.num_attention_heads

    if config.visual_gen:
        vae_cfg = config.vae_config
        self.latent_patch_size = config.latent_patch_size
        self.timestep_shift = config.timestep_shift
        # Pixels covered by one latent patch along each side.
        self.latent_downsample = vae_cfg.downsample * config.latent_patch_size
        self.max_latent_size = config.max_latent_size
        self.latent_channel = vae_cfg.z_channels
        self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
        self.time_embedder = TimestepEmbedder(self.hidden_size)
        self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
        self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
        self.latent_pos_embed = PositionEmbedding(self.max_latent_size, self.hidden_size)

    if config.visual_und:
        vit_cfg = config.vit_config
        self.vit_model = vit_model
        self.vit_patch_size = vit_cfg.patch_size
        self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
        self.vit_hidden_size = vit_cfg.hidden_size
        self.connector = MLPconnector(
            self.vit_hidden_size, self.hidden_size, config.connector_act
        )
        self.vit_pos_embed = PositionEmbedding(
            self.vit_max_num_patch_per_side, self.hidden_size
        )

    self.get_flattened_position_ids = (
        get_flattened_position_ids_interpolate
        if config.interpolate_pos
        else get_flattened_position_ids_extrapolate
    )

    self.config = config
    self._init_weights()
def _init_weights(self):
    """Zero-initialise the ``llm2vae`` projection when the generation branch is enabled."""
    if not self.config.visual_gen:
        return
    for param in (self.llm2vae.weight, self.llm2vae.bias):
        nn.init.constant_(param, 0)
def forward(
    self,
    sequence_length: int,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    sample_lens: List[int],
    packed_position_ids: torch.LongTensor,
    nested_attention_masks: List[torch.Tensor] = None,
    split_lens: List[int] = None,
    attn_modes: List[str] = None,
    # for visual understanding
    ce_loss_indexes: Optional[torch.BoolTensor] = None,
    packed_label_ids: Optional[torch.LongTensor] = None,
    packed_vit_tokens: Optional[torch.Tensor] = None,
    packed_vit_token_indexes: Optional[torch.LongTensor] = None,
    packed_vit_position_ids: Optional[torch.LongTensor] = None,
    vit_token_seqlens: Optional[torch.IntTensor] = None,
    # for visual generation
    padded_latent: Optional[torch.Tensor] = None,
    patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
    packed_latent_position_ids: Optional[torch.LongTensor] = None,
    packed_vae_token_indexes: Optional[torch.LongTensor] = None,
    packed_timesteps: Optional[torch.LongTensor] = None,
    mse_loss_indexes: Optional[torch.BoolTensor] = None,
) -> Dict[str, Optional[torch.Tensor]]:
    """
    Args:
        sequence_length: length of sequence.
        packed_text_ids: 1-D int tensor, packed text token ids.
        packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
        sample_lens: A list of N ints, length of each sample in packed_sequence.
        nested_attention_masks: A list of N 2-D float tensor,  where 0.0 means attention and
            -inf means ignore. When None, a block mask is built from ``sample_lens``,
            ``split_lens`` and ``attn_modes`` instead.
        packed_position_ids: packed 1-D positions, an image has only one global position shared
            by all latent tokens.

        packed_vit_tokens: packed patchified image tokens for vit model.
        packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
        packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
        vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
        packed_label_ids: 1-D int tensor, packed label token ids.
        ce_loss_indexes: 1-D bool tensor, where to compute ce loss.

        padded_latent: padded latent from VAE encoder.
        patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image.
        packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
        packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
        packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
        mse_loss_indexes: 1-D bool tensor, where to compute mse loss.

    Returns:
        A dict with keys ``"mse"`` and ``"ce"``: the unreduced squared error of the
        generation branch and the unreduced cross-entropy of the text branch.
        Each entry is None when the corresponding branch is not active.
    """
    packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
    packed_sequence = packed_text_embedding.new_zeros(size=(sequence_length, self.hidden_size))
    packed_sequence[packed_text_indexes] = packed_text_embedding

    if nested_attention_masks is None:
        # Imported here because the module-level import is commented out; doing it
        # lazily keeps the module importable on torch builds without flex attention.
        from torch.nn.attention.flex_attention import create_block_mask

        sparse_mask = create_sparse_mask(
            sample_lens, split_lens, attn_modes, packed_text_embedding.device
        )
        seqlen = sum(sample_lens)
        block_mask = create_block_mask(
            sparse_mask,
            B=1,
            H=self.num_heads,
            Q_LEN=seqlen,
            KV_LEN=seqlen,
            device=packed_text_embedding.device,
            BLOCK_SIZE=128,
            _compile=True,
        )
        attention_mask = block_mask
    else:
        attention_mask = nested_attention_masks

    if self.config.visual_und:
        # Cumulative sequence lengths with a leading 0, as int32.
        cu_seqlens = torch.nn.functional.pad(torch.cumsum(vit_token_seqlens, dim=0), (1, 0))
        cu_seqlens = cu_seqlens.to(torch.int32)
        max_seqlen = torch.max(vit_token_seqlens).item()
        packed_vit_token_embed = self.vit_model(
            packed_pixel_values=packed_vit_tokens,
            packed_flattened_position_ids=packed_vit_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        packed_vit_token_embed = self.connector(packed_vit_token_embed)
        vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
        packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
        packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed

    if self.config.visual_gen:
        p = self.latent_patch_size
        packed_latent = []
        for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
            # Crop the padding, then patchify: (c, h*p, w*p) -> (h*w, p*p*c).
            latent = latent[:, : h * p, : w * p].reshape(self.latent_channel, h, p, w, p)
            latent = torch.einsum("chpwq->hwpqc", latent).reshape(
                -1, p * p * self.latent_channel
            )
            packed_latent.append(latent)
        packed_latent_clean = torch.cat(packed_latent, dim=0)

        noise = torch.randn_like(packed_latent_clean)
        packed_timesteps = torch.sigmoid(packed_timesteps)
        packed_timesteps = (
            self.timestep_shift
            * packed_timesteps
            / (1 + (self.timestep_shift - 1) * packed_timesteps)
        )
        # Linear interpolation between the clean latent (t = 0) and noise (t = 1).
        packed_latent = (
            1 - packed_timesteps[:, None]
        ) * packed_latent_clean + packed_timesteps[:, None] * noise
        packed_timestep_embeds = self.time_embedder(packed_timesteps)
        latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
        packed_latent = (
            self.vae2llm(packed_latent) + packed_timestep_embeds + latent_token_pos_emb
        )
        packed_sequence[packed_vae_token_indexes] = packed_latent

    extra_inputs = {}
    if self.use_moe:
        packed_und_token_indexes = packed_text_indexes
        if packed_vit_token_indexes is not None:
            packed_und_token_indexes = torch.cat(
                [packed_text_indexes, packed_vit_token_indexes], dim=0
            )
        extra_inputs.update(
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_vae_token_indexes,
        )

    last_hidden_state = self.language_model(
        packed_sequence=packed_sequence,
        sample_lens=sample_lens,
        attention_mask=attention_mask,
        packed_position_ids=packed_position_ids,
        **extra_inputs,
    )

    mse = None
    if self.config.visual_gen:
        packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
        target = (
            noise - packed_latent_clean
        )  # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
        has_mse = packed_timesteps > 0
        mse = (packed_mse_preds - target[has_mse]) ** 2

    ce = None
    if ce_loss_indexes is not None:
        packed_ce_preds = self.language_model.lm_head(last_hidden_state[ce_loss_indexes])
        ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")

    return dict(mse=mse, ce=ce)
def prepare_prompts(self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids):
    """Tokenize prompts and build the packed inputs for a text cache update.

    Each prompt is wrapped in the ``bos_token_id`` / ``eos_token_id`` tokens.
    For every sample, the indexes of the already cached tokens come first in
    the packed layout, followed by the indexes of the new text tokens.

    Returns:
        ``(generation_input, newlens, new_rope)`` -- a dict of tensors, the
        updated cache length per sample and the updated position id per sample.
    """
    bos_id = new_token_ids["bos_token_id"]
    eos_id = new_token_ids["eos_token_id"]

    packed_text_ids = []
    packed_text_position_ids = []
    text_token_lens = []
    packed_text_indexes = []
    packed_key_value_indexes = []
    newlens = []
    new_rope = []

    offset = 0  # running index into the packed (cache + new tokens) layout
    for prompt, kv_len, position_id in zip(prompts, curr_kvlens, curr_rope):
        # Slots of the tokens already held in the cache.
        packed_key_value_indexes.extend(range(offset, offset + kv_len))
        offset += kv_len

        text_ids = tokenizer.encode(prompt)
        text_ids = [bos_id] + text_ids + [eos_id]
        num_tokens = len(text_ids)

        text_token_lens.append(num_tokens)
        packed_text_ids.extend(text_ids)
        packed_text_position_ids.extend(range(position_id, position_id + num_tokens))
        packed_text_indexes.extend(range(offset, offset + num_tokens))
        newlens.append(kv_len + num_tokens)
        new_rope.append(position_id + num_tokens)
        offset += num_tokens

    generation_input = {
        "text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
        "packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
        "packed_text_position_ids": torch.tensor(packed_text_position_ids, dtype=torch.long),
        "packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
        "packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
    }

    return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_text(
    self,
    past_key_values: NaiveCache,
    packed_text_ids: torch.IntTensor,
    packed_text_position_ids: torch.LongTensor,
    text_token_lens: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_key_value_indexes: torch.LongTensor,
    key_values_lens: torch.IntTensor,
):
    """Run the text tokens through the language model and return the updated cache.

    The token ids are embedded with the language model's ``embed_tokens`` and
    passed to ``forward_inference`` with causal attention and
    ``update_past_key_values=True``. When ``self.use_moe`` is set, the extra
    keyword ``mode="und"`` is forwarded as well.

    Returns:
        The ``past_key_values`` attribute of the ``forward_inference`` output.
    """
    text_embeds = self.language_model.model.embed_tokens(packed_text_ids)

    moe_kwargs = {"mode": "und"} if self.use_moe else {}

    result = self.language_model.forward_inference(
        packed_query_sequence=text_embeds,
        query_lens=text_token_lens,
        packed_query_position_ids=packed_text_position_ids,
        packed_query_indexes=packed_text_indexes,
        past_key_values=past_key_values,
        packed_key_value_indexes=packed_key_value_indexes,
        key_values_lens=key_values_lens,
        update_past_key_values=True,
        is_causal=True,
        **moe_kwargs,
    )
    return result.past_key_values
def prepare_vit_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids):
    """Build the packed inputs needed to append ViT-encoded images to existing contexts.

    Every image contributes a ``start_of_image`` token, its patch tokens
    (from ``patchify``) and an ``end_of_image`` token. All of these share the
    sample's current position id.

    Args:
        curr_kvlens: per-sample number of entries already in the KV cache.
        curr_rope: per-sample position id assigned to the whole image segment.
        images: one image per sample, in whatever form ``transforms`` accepts.
        transforms: callable turning an image into a tensor whose dimensions
            1 and 2 are passed to ``get_flattened_position_ids`` as the image size.
        new_token_ids: mapping holding ``start_of_image`` and ``end_of_image``.

    Returns:
        ``(generation_input, newlens, new_rope)``: a dict of tensors, the
        per-sample cache lengths after adding the image segment, and the
        per-sample position ids advanced by one.
    """
    text_ids, text_indexes = [], []
    patch_tokens, patch_position_ids, patch_counts, patch_indexes = [], [], [], []
    segment_lens, position_ids, query_indexes = [], [], []
    kv_indexes = []
    newlens, new_rope = [], []

    # ``query_cursor`` indexes the new tokens only; ``kv_cursor`` indexes the
    # combined layout in which each sample's cached slots precede its new tokens.
    query_cursor = kv_cursor = 0
    for image, kv_len, position_id in zip(images, curr_kvlens, curr_rope):
        kv_indexes.extend(range(kv_cursor, kv_cursor + kv_len))
        kv_cursor += kv_len

        soi_id = new_token_ids["start_of_image"]
        pixels = transforms(image)
        flat_position_ids = self.get_flattened_position_ids(
            pixels.size(1),
            pixels.size(2),
            self.vit_patch_size,
            max_num_patches_per_side=self.vit_max_num_patch_per_side,
        )
        patches = patchify(pixels, self.vit_patch_size)
        n_patches = patches.shape[0]
        eoi_id = new_token_ids["end_of_image"]
        segment_len = n_patches + 2  # patches plus the two boundary tokens

        patch_tokens.append(patches)
        patch_position_ids.append(flat_position_ids)
        patch_counts.append(n_patches)

        # Layout inside the segment: [start_of_image, patches..., end_of_image].
        text_ids.extend((soi_id, eoi_id))
        text_indexes.extend((query_cursor, query_cursor + 1 + n_patches))
        patch_indexes.extend(range(query_cursor + 1, query_cursor + 1 + n_patches))
        query_indexes.extend(range(kv_cursor, kv_cursor + segment_len))
        query_cursor += segment_len
        kv_cursor += segment_len

        position_ids.extend([position_id] * segment_len)
        segment_lens.append(segment_len)
        newlens.append(kv_len + segment_len)
        new_rope.append(position_id + 1)

    generation_input = {
        "packed_text_ids": torch.tensor(text_ids, dtype=torch.long),
        "packed_text_indexes": torch.tensor(text_indexes, dtype=torch.long),
        "vit_token_seqlens": torch.tensor(patch_counts, dtype=torch.int),
        "packed_vit_tokens": torch.cat(patch_tokens, dim=0),
        "packed_vit_position_ids": torch.cat(patch_position_ids, dim=0),
        "packed_vit_token_indexes": torch.tensor(patch_indexes, dtype=torch.long),
        "packed_position_ids": torch.tensor(position_ids, dtype=torch.long),
        "packed_seqlens": torch.tensor(segment_lens, dtype=torch.int),
        "packed_indexes": torch.tensor(query_indexes, dtype=torch.long),
        "packed_key_value_indexes": torch.tensor(kv_indexes, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
    }
    return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vit(
    self,
    past_key_values: NaiveCache,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_vit_tokens: torch.Tensor,
    packed_vit_token_indexes: torch.LongTensor,
    packed_vit_position_ids: torch.LongTensor,
    vit_token_seqlens: torch.IntTensor,
    packed_position_ids: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    packed_indexes: torch.LongTensor,
    packed_key_value_indexes: torch.LongTensor,
    key_values_lens: torch.IntTensor,
):
    """Encode image patches with the ViT and add the segment to the KV cache.

    A zero buffer of ``sum(packed_seqlens)`` rows is filled with the text-token
    embeddings at ``packed_text_indexes`` and with the connector-projected ViT
    embeddings (plus ``vit_pos_embed``) at ``packed_vit_token_indexes``. The
    buffer is then passed to ``forward_inference`` with non-causal attention
    and ``update_past_key_values=True``.

    Returns:
        The ``past_key_values`` attribute of the ``forward_inference`` output.
    """
    text_embeds = self.language_model.model.embed_tokens(packed_text_ids)
    sequence = text_embeds.new_zeros((sum(packed_seqlens), self.hidden_size))
    sequence[packed_text_indexes] = text_embeds

    # Cumulative patch counts with a leading zero, as int32 boundaries per image.
    boundaries = torch.cumsum(vit_token_seqlens, dim=0)
    boundaries = torch.nn.functional.pad(boundaries, (1, 0)).to(torch.int32)
    longest = torch.max(vit_token_seqlens).item()

    vit_embeds = self.vit_model(
        packed_pixel_values=packed_vit_tokens,
        packed_flattened_position_ids=packed_vit_position_ids,
        cu_seqlens=boundaries,
        max_seqlen=longest,
    )
    vit_embeds = self.connector(vit_embeds)
    vit_embeds = vit_embeds + self.vit_pos_embed(packed_vit_position_ids)
    # Match the buffer dtype before scattering into it.
    if vit_embeds.dtype != sequence.dtype:
        vit_embeds = vit_embeds.to(sequence.dtype)
    sequence[packed_vit_token_indexes] = vit_embeds

    moe_kwargs = {"mode": "und"} if self.use_moe else {}

    result = self.language_model.forward_inference(
        packed_query_sequence=sequence,
        query_lens=packed_seqlens,
        packed_query_position_ids=packed_position_ids,
        packed_query_indexes=packed_indexes,
        past_key_values=past_key_values,
        packed_key_value_indexes=packed_key_value_indexes,
        key_values_lens=key_values_lens,
        update_past_key_values=True,
        is_causal=False,
        **moe_kwargs,
    )
    return result.past_key_values
def prepare_vae_images(self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0):
    """Build the packed inputs needed to append VAE-encoded images to existing contexts.

    Every image contributes a ``start_of_image`` token, ``h * w`` latent
    tokens (image height/width floor-divided by ``self.latent_downsample``)
    and an ``end_of_image`` token, all sharing the sample's current position
    id. The transformed images are zero-padded to a common size and stacked.

    Args:
        curr_kvlens: per-sample number of entries already in the KV cache.
        curr_rope: per-sample position id assigned to the whole image segment.
        images: one image per sample, in whatever form ``transforms`` accepts.
        transforms: callable returning a 3-D tensor; dimensions 1 and 2 are
            used as height and width.
        new_token_ids: mapping holding ``start_of_image`` and ``end_of_image``.
        timestep: value stored as the single entry of ``packed_timesteps``.

    Returns:
        ``(generation_input, newlens, new_rope)``: a dict of inputs, the
        per-sample cache lengths after adding the image segment, and the
        per-sample position ids advanced by one.
    """
    latent_shapes, latent_position_ids, latent_indexes = [], [], []
    text_ids, text_indexes = [], []
    segment_lens, position_ids, query_indexes = [], [], []
    kv_indexes = []
    image_tensors = []
    newlens, new_rope = [], []

    # ``query_cursor`` indexes the new tokens only; ``kv_cursor`` indexes the
    # combined layout in which each sample's cached slots precede its new tokens.
    query_cursor = kv_cursor = 0
    for image, kv_len, position_id in zip(images, curr_kvlens, curr_rope):
        kv_indexes.extend(range(kv_cursor, kv_cursor + kv_len))
        kv_cursor += kv_len

        soi_id = new_token_ids["start_of_image"]
        pixels = transforms(image)
        image_tensors.append(pixels)
        latent_position_ids.append(
            self.get_flattened_position_ids(
                pixels.size(1),
                pixels.size(2),
                self.latent_downsample,
                max_num_patches_per_side=self.max_latent_size,
            )
        )
        height, width = pixels.shape[1:]
        h = height // self.latent_downsample
        w = width // self.latent_downsample
        latent_shapes.append((h, w))
        n_latents = w * h
        eoi_id = new_token_ids["end_of_image"]
        segment_len = n_latents + 2  # latents plus the two boundary tokens

        # Layout inside the segment: [start_of_image, latents..., end_of_image].
        text_ids.extend((soi_id, eoi_id))
        text_indexes.extend((query_cursor, query_cursor + 1 + n_latents))
        latent_indexes.extend(range(query_cursor + 1, query_cursor + 1 + n_latents))
        query_indexes.extend(range(kv_cursor, kv_cursor + segment_len))
        query_cursor += segment_len
        kv_cursor += segment_len

        position_ids.extend([position_id] * segment_len)
        segment_lens.append(segment_len)
        newlens.append(kv_len + segment_len)
        new_rope.append(position_id + 1)

    # Zero-pad every image to the largest size seen along each dimension.
    max_dims = [max(dims) for dims in zip(*(tensor.shape for tensor in image_tensors))]
    padded_images = torch.zeros(size=(len(image_tensors), *max_dims))
    for row, tensor in enumerate(image_tensors):
        padded_images[row, :, : tensor.shape[1], : tensor.shape[2]] = tensor

    generation_input = {
        "padded_images": padded_images,
        "patchified_vae_latent_shapes": latent_shapes,
        "packed_vae_position_ids": torch.cat(latent_position_ids, dim=0),
        "packed_timesteps": torch.tensor([timestep]),
        "packed_vae_token_indexes": torch.tensor(latent_indexes, dtype=torch.long),
        "packed_text_ids": torch.tensor(text_ids, dtype=torch.long),
        "packed_text_indexes": torch.tensor(text_indexes, dtype=torch.long),
        "packed_position_ids": torch.tensor(position_ids, dtype=torch.long),
        "packed_seqlens": torch.tensor(segment_lens, dtype=torch.int),
        "packed_indexes": torch.tensor(query_indexes, dtype=torch.long),
        "packed_key_value_indexes": torch.tensor(kv_indexes, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
    }
    return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vae(
    self,
    vae_model,
    past_key_values: NaiveCache,
    padded_images: torch.Tensor,
    patchified_vae_latent_shapes: List,
    packed_vae_position_ids: torch.LongTensor,
    packed_timesteps: torch.Tensor,
    packed_vae_token_indexes: torch.LongTensor,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_position_ids: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    packed_indexes: torch.LongTensor,
    key_values_lens: torch.IntTensor,
    packed_key_value_indexes: torch.Tensor,
):
    """Encode images with the VAE and add the resulting segment to the KV cache.

    Each latent returned by ``vae_model.encode`` is cropped to its
    ``(h, w)`` entry of ``patchified_vae_latent_shapes`` (scaled by
    ``self.latent_patch_size``) and rearranged into one row per patch. The
    rows are projected with ``vae2llm``, summed with the timestep and position
    embeddings, scattered into a zero buffer next to the text-token
    embeddings, and passed to ``forward_inference`` with non-causal attention
    and ``update_past_key_values=True``.

    Returns:
        The ``past_key_values`` attribute of the ``forward_inference`` output.
    """
    text_embeds = self.language_model.model.embed_tokens(packed_text_ids)
    sequence = text_embeds.new_zeros((sum(packed_seqlens), self.hidden_size))
    sequence[packed_text_indexes] = text_embeds

    encoded = vae_model.encode(padded_images)

    patch = self.latent_patch_size
    patch_rows = []
    for latent, (h, w) in zip(encoded, patchified_vae_latent_shapes):
        # Drop the padding, then split both spatial axes into (count, patch).
        cropped = latent[:, : h * patch, : w * patch]
        grid = cropped.reshape(self.latent_channel, h, patch, w, patch)
        # One row per (h, w) cell, holding that cell's patch*patch*channel values.
        patch_rows.append(
            torch.einsum("chpwq->hwpqc", grid).reshape(-1, patch * patch * self.latent_channel)
        )
    latent_tokens = torch.cat(patch_rows, dim=0)

    pos_embed = self.latent_pos_embed(packed_vae_position_ids)
    time_embed = self.time_embedder(packed_timesteps)
    latent_tokens = self.vae2llm(latent_tokens) + time_embed + pos_embed
    # Match the buffer dtype before scattering into it.
    if latent_tokens.dtype != sequence.dtype:
        latent_tokens = latent_tokens.to(sequence.dtype)
    sequence[packed_vae_token_indexes] = latent_tokens

    if self.use_moe:
        moe_kwargs = {
            "mode": "gen",
            "packed_vae_token_indexes": packed_vae_token_indexes,
            "packed_text_indexes": packed_text_indexes,
        }
    else:
        moe_kwargs = {}

    result = self.language_model.forward_inference(
        packed_query_sequence=sequence,
        query_lens=packed_seqlens,
        packed_query_position_ids=packed_position_ids,
        packed_query_indexes=packed_indexes,
        past_key_values=past_key_values,
        key_values_lens=key_values_lens,
        packed_key_value_indexes=packed_key_value_indexes,
        update_past_key_values=True,
        is_causal=False,
        **moe_kwargs,
    )
    return result.past_key_values
def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
    """Build the packed inputs for generating images from random initial noise.

    For every ``(H, W)`` in ``image_sizes`` the segment consists of a
    ``start_of_image`` token, ``(H // latent_downsample) * (W // latent_downsample)``
    latent tokens initialised with ``torch.randn``, and an ``end_of_image``
    token. All tokens of a segment share the sample's current position id.

    Args:
        curr_kvlens: per-sample number of entries already in the KV cache.
        curr_rope: per-sample position id assigned to the whole image segment.
        image_sizes: per-sample ``(H, W)`` pairs.
        new_token_ids: mapping holding ``start_of_image`` and ``end_of_image``.

    Returns:
        A dict of tensors describing the packed sequence and the initial noise.
    """
    text_ids, text_indexes = [], []
    latent_position_ids, latent_indexes, init_noises = [], [], []
    position_ids, segment_lens, query_indexes = [], [], []
    kv_indexes = []

    noise_width = self.latent_channel * self.latent_patch_size**2

    # ``query_cursor`` indexes the new tokens only; ``kv_cursor`` indexes the
    # combined layout in which each sample's cached slots precede its new tokens.
    query_cursor = kv_cursor = 0
    for (height, width), kv_len, position_id in zip(image_sizes, curr_kvlens, curr_rope):
        kv_indexes.extend(range(kv_cursor, kv_cursor + kv_len))
        kv_cursor += kv_len

        soi_id = new_token_ids["start_of_image"]
        latent_position_ids.append(
            self.get_flattened_position_ids(
                height,
                width,
                self.latent_downsample,
                max_num_patches_per_side=self.max_latent_size,
            )
        )
        n_latents = (height // self.latent_downsample) * (width // self.latent_downsample)
        init_noises.append(torch.randn(n_latents, noise_width))
        eoi_id = new_token_ids["end_of_image"]
        segment_len = n_latents + 2  # latents plus the two boundary tokens

        # Layout inside the segment: [start_of_image, latents..., end_of_image].
        text_ids.extend((soi_id, eoi_id))
        text_indexes.extend((query_cursor, query_cursor + 1 + n_latents))
        latent_indexes.extend(range(query_cursor + 1, query_cursor + 1 + n_latents))
        query_indexes.extend(range(kv_cursor, kv_cursor + segment_len))
        query_cursor += segment_len
        kv_cursor += segment_len

        position_ids.extend([position_id] * segment_len)
        segment_lens.append(segment_len)

    return {
        "packed_text_ids": torch.tensor(text_ids, dtype=torch.long),
        "packed_text_indexes": torch.tensor(text_indexes, dtype=torch.long),
        "packed_init_noises": torch.cat(init_noises, dim=0),
        "packed_vae_position_ids": torch.cat(latent_position_ids, dim=0),
        "packed_vae_token_indexes": torch.tensor(latent_indexes, dtype=torch.long),
        "packed_seqlens": torch.tensor(segment_lens, dtype=torch.int),
        "packed_position_ids": torch.tensor(position_ids, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        "packed_indexes": torch.tensor(query_indexes, dtype=torch.long),
        "packed_key_value_indexes": torch.tensor(kv_indexes, dtype=torch.long),
    }
def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
    """Build the index tensors for a classifier-free-guidance pass over image latents.

    For every sample the packed layout is: the sample's cached key/value
    slots, followed by one image segment of ``num_image_tokens + 2`` query
    slots (two boundary tokens around the latent tokens), where
    ``num_image_tokens = (H // latent_downsample) * (W // latent_downsample)``.
    All query slots of a segment share the sample's current position id.

    Args:
        curr_kvlens: per-sample number of entries already in the KV cache.
        curr_rope: per-sample position id assigned to the whole image segment.
        image_sizes: per-sample ``(H, W)`` pairs.

    Returns:
        A dict with ``cfg_packed_position_ids``, ``cfg_key_values_lens``,
        ``cfg_packed_query_indexes`` and ``cfg_packed_key_value_indexes``.
    """
    packed_position_ids, packed_indexes, packed_key_value_indexes = [], [], []

    curr = 0  # running offset into the packed (cache + query) layout
    for (H, W), curr_kvlen, curr_position_id in zip(image_sizes, curr_kvlens, curr_rope):
        packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
        curr += curr_kvlen

        h, w = H // self.latent_downsample, W // self.latent_downsample
        # Start token + latent tokens + end token occupy one contiguous run.
        segment_len = h * w + 2
        packed_indexes.extend(range(curr, curr + segment_len))
        curr += segment_len

        packed_position_ids.extend([curr_position_id] * segment_len)

    generation_input = {
        "cfg_packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
        "cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        "cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
        "cfg_packed_key_value_indexes": torch.tensor(packed_key_value_indexes, dtype=torch.long),
    }
    return generation_input
@torch.no_grad
def generate_image(
    self,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_init_noises: torch.Tensor,
    packed_vae_position_ids: torch.LongTensor,
    packed_vae_token_indexes: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    packed_position_ids: torch.LongTensor,
    packed_indexes: torch.LongTensor,
    past_key_values: NaiveCache,
    key_values_lens: torch.IntTensor,
    packed_key_value_indexes: torch.LongTensor,
    num_timesteps: int = 24,
    timestep_shift: float = 1.0,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    cfg_interval: Optional[Tuple[float, float]] = (0, 1),
    # cfg_text
    cfg_text_scale: float = 1.0,
    cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_text_past_key_values: Optional[NaiveCache] = None,
    cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
    cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    # cfg_img
    cfg_img_scale: float = 1.0,
    cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_img_past_key_values: Optional[NaiveCache] = None,
    cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
    cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    cfg_type: str = "parallel",
    # cache_args
    enable_taylorseer=False,
):
    """Integrate the flow from ``packed_init_noises`` down to latents.

    A schedule of ``num_timesteps`` points from 1 to 0 is warped by
    ``timestep_shift``; for every point except the last, ``_forward_flow``
    predicts a velocity ``v_t`` and the latents are updated with
    ``x_t - v_t * dt``.

    Args:
        packed_init_noises: initial latents; one row per latent token.
        packed_seqlens: per-image segment lengths; each includes the two
            boundary tokens, which is why 2 is subtracted when splitting.
        num_timesteps: number of schedule points (``num_timesteps - 1`` steps
            are executed).
        timestep_shift: warps the schedule as ``s * t / (1 + (s - 1) * t)``.
        cfg_interval: guidance scales are applied only for timesteps with
            ``cfg_interval[0] < t <= cfg_interval[1]``; outside that range
            both scales are passed as 1.0. ``None`` selects the default
            ``(0, 1)``.
        cfg_text_scale / cfg_img_scale and the matching ``cfg_text_*`` /
            ``cfg_img_*`` arguments: forwarded to ``_forward_flow``.
        enable_taylorseer: when true, sets
            ``language_model.model.enable_taylorseer`` and creates three
            caches with ``cache_init``; otherwise the flag is cleared and the
            cache arguments are ``None``.
        Remaining arguments are forwarded unchanged to ``_forward_flow``.

    Returns:
        A tuple of latent tensors, one per image, obtained by splitting the
        final latents by ``packed_seqlens - 2``.
    """
    # The default is a tuple so no mutable object is shared across calls; the
    # annotation allows None, which would otherwise fail on indexing below.
    if cfg_interval is None:
        cfg_interval = (0, 1)

    if enable_taylorseer:
        self.language_model.model.enable_taylorseer = True
        model_pred_cache_dic, model_pred_current = cache_init(self, num_timesteps)
        model_pred_text_cache_dic, model_pred_text_current = cache_init(self, num_timesteps)
        model_pred_img_cache_dic, model_pred_img_current = cache_init(self, num_timesteps)
    else:
        self.language_model.model.enable_taylorseer = False
        model_pred_cache_dic, model_pred_current = None, None
        model_pred_text_cache_dic, model_pred_text_current = None, None
        model_pred_img_cache_dic, model_pred_img_current = None, None

    x_t = packed_init_noises

    timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
    timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
    dts = timesteps[:-1] - timesteps[1:]
    timesteps = timesteps[:-1]

    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        # One timestep value per latent token. ``repeat`` builds it on-device
        # instead of going through a Python list of zero-dim tensors.
        timestep = t.repeat(x_t.shape[0])
        if t > cfg_interval[0] and t <= cfg_interval[1]:
            cfg_text_scale_ = cfg_text_scale
            cfg_img_scale_ = cfg_img_scale
        else:
            cfg_text_scale_ = 1.0
            cfg_img_scale_ = 1.0
        v_t = self._forward_flow(
            x_t=x_t,
            timestep=timestep,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_vae_position_ids=packed_vae_position_ids,
            packed_text_ids=packed_text_ids,
            packed_text_indexes=packed_text_indexes,
            packed_position_ids=packed_position_ids,
            packed_indexes=packed_indexes,
            packed_seqlens=packed_seqlens,
            key_values_lens=key_values_lens,
            past_key_values=past_key_values,
            packed_key_value_indexes=packed_key_value_indexes,
            cfg_renorm_min=cfg_renorm_min,
            cfg_renorm_type=cfg_renorm_type,
            # cfg_text
            cfg_text_scale=cfg_text_scale_,
            cfg_text_packed_position_ids=cfg_text_packed_position_ids,
            cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
            cfg_text_key_values_lens=cfg_text_key_values_lens,
            cfg_text_past_key_values=cfg_text_past_key_values,
            cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
            # cfg_img
            cfg_img_scale=cfg_img_scale_,
            cfg_img_packed_position_ids=cfg_img_packed_position_ids,
            cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
            cfg_img_key_values_lens=cfg_img_key_values_lens,
            cfg_img_past_key_values=cfg_img_past_key_values,
            cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
            cfg_type=cfg_type,
            # cache
            model_pred_cache_dic=model_pred_cache_dic,
            model_pred_current=model_pred_current,
            model_pred_text_cache_dic=model_pred_text_cache_dic,
            model_pred_text_current=model_pred_text_current,
            model_pred_img_cache_dic=model_pred_img_cache_dic,
            model_pred_img_current=model_pred_img_current,
        )

        # velocity pointing from data to noise
        x_t = x_t - v_t.to(x_t.device) * dts[i]

    if enable_taylorseer:
        del model_pred_cache_dic, model_pred_current
        del model_pred_text_cache_dic, model_pred_text_current
        del model_pred_img_cache_dic, model_pred_img_current

    # Each segment length counts two boundary tokens that carry no latent.
    unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
    return unpacked_latent
@torch.no_grad
def _forward_flow(
    self,
    x_t: torch.Tensor,
    timestep: torch.LongTensor,
    packed_vae_token_indexes: torch.LongTensor,
    packed_vae_position_ids: torch.LongTensor,
    packed_text_ids: torch.LongTensor,
    packed_text_indexes: torch.LongTensor,
    packed_indexes: torch.LongTensor,
    packed_position_ids: torch.LongTensor,
    packed_seqlens: torch.IntTensor,
    key_values_lens: torch.IntTensor,
    past_key_values: NaiveCache,
    packed_key_value_indexes: torch.LongTensor,
    cfg_renorm_min: float = 0.0,
    cfg_renorm_type: str = "global",
    # cfg_text
    cfg_text_scale: float = 1.0,
    cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_text_key_values_lens: Optional[torch.Tensor] = None,
    cfg_text_past_key_values: Optional[NaiveCache] = None,
    cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    # cfg_img
    cfg_img_scale: float = 1.0,
    cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
    cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
    cfg_img_key_values_lens: Optional[torch.Tensor] = None,
    cfg_img_past_key_values: Optional[NaiveCache] = None,
    cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
    cfg_type: str = "parallel",
    # cache
    model_pred_cache_dic: Optional[Dict[str, Any]] = None,
    model_pred_current: Optional[int] = None,
    model_pred_text_cache_dic: Optional[Dict[str, Any]] = None,
    model_pred_text_current: Optional[int] = None,
    model_pred_img_cache_dic: Optional[Dict[str, Any]] = None,
    model_pred_img_current: Optional[int] = None,
):
    """Predict the velocity for the latent tokens, optionally with classifier-free guidance.

    The latents ``x_t`` are projected with ``vae2llm``, summed with the
    timestep and position embeddings and scattered, together with the
    text-token embeddings, into one packed sequence. That sequence is run
    through ``forward_inference`` (non-causal, cache not updated) and the
    rows at ``packed_vae_token_indexes`` are mapped back with ``llm2vae``.

    Guidance:
        * ``cfg_text_scale > 1.0`` triggers a second pass that uses the
          ``cfg_text_*`` indexes/cache; ``cfg_img_scale > 1.0`` triggers a
          third pass that uses the ``cfg_img_*`` indexes/cache.
        * With ``cfg_renorm_type == "text_channel"`` the text-guided velocity
          is rescaled per row before the image guidance is applied; with
          ``"global"`` or ``"channel"`` the fully guided velocity is rescaled.
          The rescale factor is clamped to ``[cfg_renorm_min, 1.0]``.
        * NOTE(review): the image-guidance result is only combined inside the
          ``cfg_text_scale > 1.0`` branch, so with ``cfg_text_scale <= 1.0``
          the unguided velocity is returned regardless of ``cfg_img_scale``.
        * ``cfg_type`` is accepted but not read in this body.

    The ``model_pred_*`` arguments are assigned to
    ``language_model.model.cache_dic`` / ``.current`` before the matching
    pass when ``language_model.model.enable_taylorseer`` is true.

    Returns:
        The (possibly guided and rescaled) velocity for the latent tokens.

    Raises:
        NotImplementedError: when text guidance is active and
            ``cfg_renorm_type`` is none of ``"text_channel"``, ``"global"``
            or ``"channel"``.
    """
    packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
    # Zero buffer with one row per token of every packed segment.
    packed_sequence = packed_text_embedding.new_zeros((sum(packed_seqlens), self.hidden_size))
    packed_sequence[packed_text_indexes] = packed_text_embedding

    # All entries of ``timestep`` must carry the same value.
    assert timestep.unique().shape[0] == 1
    packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
    packed_timestep_embeds = self.time_embedder(timestep)
    x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
    # Match the buffer dtype before scattering into it.
    if x_t.dtype != packed_sequence.dtype:
        x_t = x_t.to(packed_sequence.dtype)
    packed_sequence[packed_vae_token_indexes] = x_t

    extra_inputs = {}
    if self.use_moe:
        extra_inputs = {
            "mode": "gen",
            "packed_vae_token_indexes": packed_vae_token_indexes,
            "packed_text_indexes": packed_text_indexes,
        }

    # Point the model at the cache state belonging to the conditional pass.
    if self.language_model.model.enable_taylorseer:
        self.language_model.model.cache_dic = model_pred_cache_dic
        self.language_model.model.current = model_pred_current

    output = self.language_model.forward_inference(
        packed_query_sequence=packed_sequence,
        query_lens=packed_seqlens,
        packed_query_position_ids=packed_position_ids,
        packed_query_indexes=packed_indexes,
        past_key_values=past_key_values,
        key_values_lens=key_values_lens,
        packed_key_value_indexes=packed_key_value_indexes,
        update_past_key_values=False,
        is_causal=False,
        **extra_inputs,
    )
    v_t = self.llm2vae(output.packed_query_sequence)
    # Keep only the rows that correspond to latent tokens.
    v_t = v_t[packed_vae_token_indexes]

    if cfg_text_scale > 1.0:
        # Same packed sequence, but run against the text-CFG context.
        if self.language_model.model.enable_taylorseer:
            self.language_model.model.cache_dic = model_pred_text_cache_dic
            self.language_model.model.current = model_pred_text_current
        cfg_text_output = self.language_model.forward_inference(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=cfg_text_packed_position_ids,
            packed_query_indexes=cfg_text_packed_query_indexes,
            past_key_values=cfg_text_past_key_values,
            key_values_lens=cfg_text_key_values_lens,
            packed_key_value_indexes=cfg_text_packed_key_value_indexes,
            update_past_key_values=False,
            is_causal=False,
            **extra_inputs,
        )
        cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
        cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]

    if cfg_img_scale > 1.0:
        # Same packed sequence, but run against the image-CFG context.
        if self.language_model.model.enable_taylorseer:
            self.language_model.model.cache_dic = model_pred_img_cache_dic
            self.language_model.model.current = model_pred_img_current
        cfg_img_output = self.language_model.forward_inference(
            packed_query_sequence=packed_sequence,
            query_lens=packed_seqlens,
            packed_query_position_ids=cfg_img_packed_position_ids,
            packed_query_indexes=cfg_img_packed_query_indexes,
            past_key_values=cfg_img_past_key_values,
            key_values_lens=cfg_img_key_values_lens,
            packed_key_value_indexes=cfg_img_packed_key_value_indexes,
            update_past_key_values=False,
            is_causal=False,
            **extra_inputs,
        )
        cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
        cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]

    if cfg_text_scale > 1.0:
        if cfg_renorm_type == "text_channel":
            # Rescale the text-guided velocity per row, then add image guidance.
            v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
            norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
            norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
            scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
            v_t_text = v_t_text_ * scale
            if cfg_img_scale > 1.0:
                v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
            else:
                v_t = v_t_text
        else:
            # Apply text and image guidance first, then rescale the result.
            v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
            if cfg_img_scale > 1.0:
                v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
            else:
                v_t_ = v_t_text_

            # NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
            if cfg_renorm_type == "global":
                norm_v_t = torch.norm(v_t)
                norm_v_t_ = torch.norm(v_t_)
            elif cfg_renorm_type == "channel":
                norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
                norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
            else:
                raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted")
            scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(min=cfg_renorm_min, max=1.0)
            v_t = v_t_ * scale
    else:
        # No CFG
        pass

    return v_t
def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
    """Build the packed inputs for starting text decoding on each sample.

    Every sample gets one start token (``new_token_ids["bos_token_id"]``)
    placed at the sample's current position id. The key/value indexes
    enumerate the samples' cached slots back to back.

    Args:
        curr_kvlens: per-sample number of entries already in the KV cache.
        curr_rope: per-sample position id for the start token.
        new_token_ids: mapping holding ``bos_token_id``.

    Returns:
        A dict with ``packed_start_tokens``, ``packed_query_position_ids``,
        ``key_values_lens`` and ``packed_key_value_indexes`` tensors.
    """
    pairs = list(zip(curr_kvlens, curr_rope))

    start_tokens = [new_token_ids["bos_token_id"] for _ in pairs]
    query_position_ids = [position_id for _, position_id in pairs]

    kv_indexes = []
    offset = 0
    for kv_len, _ in pairs:
        kv_indexes.extend(range(offset, offset + kv_len))
        offset += kv_len

    return {
        "packed_start_tokens": torch.tensor(start_tokens, dtype=torch.long),
        "packed_query_position_ids": torch.tensor(query_position_ids, dtype=torch.long),
        "key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
        "packed_key_value_indexes": torch.tensor(kv_indexes, dtype=torch.long),
    }
@torch.no_grad
def generate_text(
    self,
    past_key_values: NaiveCache,
    packed_key_value_indexes: torch.LongTensor,
    key_values_lens: torch.IntTensor,
    packed_start_tokens: torch.LongTensor,
    packed_query_position_ids: torch.LongTensor,
    max_length: int,
    do_sample: bool = False,
    temperature: float = 1.0,
    end_token_id: Optional[int] = None,
):
    """Decode text token by token, starting from ``packed_start_tokens``.

    Each step feeds the current tokens (one per sample) through
    ``forward_inference`` with the cache updated, then picks the next tokens
    by sampling from ``softmax(logits / temperature)`` when ``do_sample`` is
    true, or by ``argmax`` otherwise.

    Args:
        past_key_values: cache holding the already processed context.
        packed_key_value_indexes: packed slots of the cached entries. This
            tensor is not modified.
        key_values_lens: per-sample number of cached entries.
        packed_start_tokens: first token of every sample.
        packed_query_position_ids: position id of the first token per sample.
        max_length: maximum number of decoding steps.
        end_token_id: when given, decoding stops once the first sample
            predicts this id (only batch size 1 is supported for early stop).

    Returns:
        A tensor stacking the tokens that were fed to the model at every
        step (first dimension = step). The token predicted by the final step
        is not included.
    """
    step = 0
    generated_sequence = []
    curr_tokens = packed_start_tokens
    while step < max_length:
        generated_sequence.append(curr_tokens)
        packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
        query_lens = torch.ones_like(curr_tokens)
        # Sample i's query slot sits after its cached slots, shifted by the
        # i query slots of the preceding samples.
        packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
            0,
            len(key_values_lens),
            device=key_values_lens.device,
            dtype=key_values_lens.dtype,
        )

        # Shift sample i's cached slots by i to leave room for the query
        # slots. ``split`` returns views, so the shift is done out of place
        # to keep the caller's tensor intact.
        unpacked = [
            chunk + i
            for i, chunk in enumerate(packed_key_value_indexes.split(key_values_lens.tolist(), dim=0))
        ]
        packed_key_value_indexes = torch.cat(unpacked, dim=0)

        extra_inputs = {}
        if self.use_moe:
            extra_inputs = {"mode": "und"}

        output = self.language_model.forward_inference(
            packed_query_sequence=packed_text_embedding,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=True,
            is_causal=True,
            **extra_inputs,
        )
        past_key_values = output.past_key_values
        packed_query_sequence = output.packed_query_sequence
        pred_logits = self.language_model.lm_head(packed_query_sequence)

        if do_sample:
            probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
            curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            curr_tokens = torch.argmax(pred_logits, dim=-1)

        # The token just processed is now cached: extend every sample's slot
        # list by the slot following its last one.
        unpacked = [
            torch.cat([chunk, (chunk[-1] + 1).reshape(1)], dim=0)
            for chunk in packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)
        ]
        packed_key_value_indexes = torch.cat(unpacked, dim=0)
        key_values_lens = key_values_lens + 1
        packed_query_position_ids = packed_query_position_ids + 1
        step += 1

        # only support batch=1
        if end_token_id is not None and curr_tokens[0] == end_token_id:
            break

    output_device = generated_sequence[0].device
    return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
# for evaluation
@torch.no_grad()
def chat(
    self,
    tokenizer,
    new_token_ids,
    image_transform,
    images,
    prompt,
    max_length: int,
    do_sample: bool = False,
    temperature: float = 1.0,
):
    """Answer ``prompt`` about ``images`` and return the decoded reply text.

    The images are added to a fresh cache one by one through the ViT path,
    then the prompt is added, and finally text is decoded with
    ``generate_text`` using ``new_token_ids["eos_token_id"]`` as end token.
    Tensor values of a ``new_token_ids`` dict are moved to the model's device
    in place.

    Returns:
        The decoded string between the first ``<|im_start|>`` and the first
        ``<|im_end|>`` marker.
    """
    device = next(self.parameters()).device

    def on_device(batch):
        # Move tensor values to the model device; leave everything else as is.
        return {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in batch.items()
        }

    def bf16_autocast():
        return torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16)

    if isinstance(new_token_ids, dict):
        for name, value in new_token_ids.items():
            if torch.is_tensor(value):
                new_token_ids[name] = value.to(device)
    elif torch.is_tensor(new_token_ids):
        new_token_ids = new_token_ids.to(device)

    # prefill
    cache = NaiveCache(self.config.llm_config.num_hidden_layers)
    kv_lens = [0]
    rope = [0]

    # add images
    for image in images:
        inputs, kv_lens, rope = self.prepare_vit_images(
            curr_kvlens=kv_lens,
            curr_rope=rope,
            images=[image],
            transforms=image_transform,
            new_token_ids=new_token_ids,
        )
        inputs = on_device(inputs)
        with bf16_autocast():
            cache = self.forward_cache_update_vit(cache, **inputs)

    # add text
    inputs, kv_lens, rope = self.prepare_prompts(
        curr_kvlens=kv_lens,
        curr_rope=rope,
        prompts=[prompt],
        tokenizer=tokenizer,
        new_token_ids=new_token_ids,
    )
    inputs = on_device(inputs)
    with bf16_autocast():
        cache = self.forward_cache_update_text(cache, **inputs)

    # decode
    inputs = on_device(self.prepare_start_tokens(kv_lens, rope, new_token_ids))
    with bf16_autocast():
        token_steps = self.generate_text(
            past_key_values=cache,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            end_token_id=new_token_ids["eos_token_id"],
            **inputs,
        )

    # Column 0 holds the tokens of the single sample.
    decoded = tokenizer.decode(token_steps[:, 0])
    before_end = decoded.split("<|im_end|>")[0]
    return before_end.split("<|im_start|>")[1]
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/modeling_utils.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2022 Facebook, Inc. and its affiliates.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: CC BY-NC 4.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under CC BY-NC 4.0, with the full license text
# available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
#
# This modified file is released under the same license.
import
math
import
numpy
as
np
import
torch
from
torch
import
nn
from
transformers.activations
import
ACT2FN
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Return 2D sine-cosine position embeddings for a square grid.

    Args:
        embed_dim: size of each embedding vector.
        grid_size: number of positions per grid side.
        cls_token: when true and ``extra_tokens > 0``, zero rows are prepended.
        extra_tokens: number of zero rows to prepend.

    Returns:
        Array with ``grid_size * grid_size`` rows of ``embed_dim`` values,
        preceded by ``extra_tokens`` zero rows when requested.
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid is called with the width axis first
    mesh = np.stack(np.meshgrid(coords, coords), axis=0)
    mesh = mesh.reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, mesh)
    if not (cls_token and extra_tokens > 0):
        return pos_embed

    padding = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([padding, pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Return sine-cosine embeddings for a two-plane coordinate grid.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes
    ``grid[1]``; the two halves are concatenated along the last axis.

    Args:
        embed_dim: size of each embedding vector; must be even.
        grid: coordinates, indexed as ``grid[0]`` and ``grid[1]``.

    Returns:
        Array of shape (H*W, D).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # each half has shape (H*W, D/2)
    halves = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return np.concatenate(halves, axis=1)  # (H*W, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with sine-cosine embeddings.

    embed_dim: output dimension for each position (must be even)
    pos: positions to be encoded; any array-like (list, tuple or ndarray of any
        shape) is accepted and flattened to size (M,)
    out: (M, D), where the first D/2 columns are sines and the last D/2 are cosines
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    # np.asarray lets plain Python sequences work, as the docstring promises;
    # ndarrays pass through unchanged (no copy, same dtype).
    pos = np.asarray(pos).reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    return np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
# --------------------------------------------------------
# TimestepEmbedder
# Reference:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a two-layer MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        # Layers are created in the same order as they are applied
        # (Linear -> SiLU -> Linear).
        proj_in = nn.Linear(frequency_embedding_size, hidden_size, bias=True)
        proj_out = nn.Linear(hidden_size, hidden_size, bias=True)
        self.mlp = nn.Sequential(proj_in, nn.SiLU(), proj_out)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings (cosines first, then
                 sines; a zero column is appended when `dim` is odd).
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half)
        freqs = freqs.to(device=t.device)

        angles = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
        if dim % 2:
            # Odd `dim`: pad with a zero column sliced from the embedding itself.
            pad_column = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad_column], dim=-1)
        return embedding

    def forward(self, t):
        """Embed timesteps `t` and project them through the MLP."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
class MLPconnector(nn.Module):
    """
    Two-layer MLP: Linear(in_dim -> out_dim), activation, Linear(out_dim -> out_dim).

    `hidden_act` is the key used to look up the activation in `ACT2FN`.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
        super().__init__()
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation and fc2 in sequence."""
        projected = self.fc1(hidden_states)
        activated = self.activation_fn(projected)
        return self.fc2(activated)
class PositionEmbedding(nn.Module):
    """
    Frozen 2D sine-cosine position embedding table, looked up by flat position id.

    The table has `max_num_patch_per_side ** 2` rows of width `hidden_size` and is
    registered as a parameter with `requires_grad=False`.
    """

    def __init__(self, max_num_patch_per_side, hidden_size):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size

        num_positions = max_num_patch_per_side**2
        self.pos_embed = nn.Parameter(torch.zeros(num_positions, hidden_size), requires_grad=False)
        self._init_weights()

    def _init_weights(self):
        # Initialize (and freeze) pos_embed by sin-cos embedding:
        table = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
        table = torch.from_numpy(table).float()
        self.pos_embed.data.copy_(table)

    def forward(self, position_ids):
        """Return the embedding rows selected by `position_ids`."""
        return self.pos_embed[position_ids]
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/qwen2_navit.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
torch
import
nn
from
torch.nn.attention
import
SDPBackend
,
sdpa_kernel
# from torch.nn.attention.flex_attention import flex_attention
from
torch.nn.functional
import
scaled_dot_product_attention
from
transformers.utils
import
ModelOutput
from
..cache_utils.taylorseer
import
(
cal_type
,
derivative_approximation
,
taylor_cache_init
,
taylor_formula
,
)
from
..qwen2.configuration_qwen2
import
Qwen2Config
as
_Qwen2Config
from
..qwen2.modeling_qwen2
import
(
Qwen2Attention
,
Qwen2MLP
,
Qwen2PreTrainedModel
,
Qwen2RMSNorm
,
Qwen2RotaryEmbedding
,
apply_rotary_pos_emb
,
)
# Import-time side effect: set TorchDynamo's cache_size_limit and
# accumulated_cache_size_limit for the whole process.
# NOTE(review): the torch.compile(flex_attention) lines below are commented out;
# confirm these limits are still needed.
torch
.
_dynamo
.
config
.
cache_size_limit
=
512
torch
.
_dynamo
.
config
.
accumulated_cache_size_limit
=
4096
# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
# flex_attention = torch.compile(flex_attention)
class Qwen2Config(_Qwen2Config):
    r"""
    Configuration for the Qwen2 model variant defined in this module.

    Every standard Qwen2 argument (plus `is_causal` and `_attn_implementation`) is forwarded unchanged to the
    parent configuration class. Three extra options, `qk_norm`, `layer_module` and `freeze_und`, are stored on
    the configuration object itself. Configuration objects inherit from [`PretrainedConfig`]; read its
    documentation for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size, i.e. the number of different tokens representable by `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 22016):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 32):
            Number of key/value heads used for Grouped Query Attention. Equal to `num_attention_heads` gives
            Multi Head Attention (MHA), `1` gives Multi Query Attention (MQA), anything else GQA. When
            converting a multi-head checkpoint to GQA, each group key and value head should be constructed by
            meanpooling the original heads within that group
            (see [this paper](https://arxiv.org/pdf/2305.13245.pdf)).
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated_normal_initializer for all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Scaling configuration for the RoPE embeddings. If a new rope type is applied and longer
            `max_position_embeddings` are expected, update that value accordingly. Expected contents:
                `rope_type` (`str`):
                    One of ['default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3'], with 'default'
                    being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. In most scaling types, a `factor` of x enables
                    sequences of length x * original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used
                    during pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. Scaling factor applied on the attention computation. If
                    unspecified, the implementation infers a value from `factor`.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for extrapolation (only) in the linear ramp function.
                    Defaults to 32 if unspecified.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Boundary for interpolation (only) in the linear ramp function.
                    Defaults to 1 if unspecified.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. Scaling factors for short contexts; a list whose length is the
                    hidden size divided by the number of attention heads divided by 2.
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. Scaling factors for long contexts; a list whose length is the
                    hidden size divided by the number of attention heads divided by 2.
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE.
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE.
        use_sliding_window (`bool`, *optional*, defaults to `False`):
            Whether to use sliding window attention.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention (SWA) window size.
        max_window_layers (`int`, *optional*, defaults to 28):
            The number of layers that use SWA. The bottom layers use SWA while the top use full attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        is_causal (`bool`, *optional*, defaults to `True`):
            Forwarded to the parent configuration.
        _attn_implementation (`str`, *optional*, defaults to `"flash_attention_2"`):
            Forwarded to the parent configuration.
        qk_norm (`bool`, *optional*, defaults to `True`):
            When true, the packed attention layers in this module apply RMSNorm to the per-head query and key
            states; otherwise they use an identity.
        layer_module (`str`, *optional*, defaults to `"Qwen2DecoderLayer"`):
            Stored as given; presumably the name of the decoder layer class to build (confirm at the use site).
        freeze_und (`bool`, *optional*, defaults to `False`):
            When true, `PackedAttentionMoT.forward_train` detaches the query/key/value states of the
            understanding tokens.

    ```python
    >>> # Initializing a Qwen2 style configuration
    >>> configuration = Qwen2Config()

    >>> # Accessing a field
    >>> configuration.qk_norm
    True
    ```"""

    model_type = "qwen2"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        is_causal=True,
        _attn_implementation="flash_attention_2",
        qk_norm=True,
        layer_module="Qwen2DecoderLayer",
        freeze_und=False,
        **kwargs,
    ):
        # Arguments understood by the parent configuration, forwarded verbatim.
        parent_args = dict(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            hidden_act=hidden_act,
            max_position_embeddings=max_position_embeddings,
            initializer_range=initializer_range,
            rms_norm_eps=rms_norm_eps,
            use_cache=use_cache,
            tie_word_embeddings=tie_word_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            use_sliding_window=use_sliding_window,
            sliding_window=sliding_window,
            max_window_layers=max_window_layers,
            attention_dropout=attention_dropout,
            is_causal=is_causal,
            _attn_implementation=_attn_implementation,
        )
        super().__init__(**parent_args, **kwargs)

        # Extension fields specific to this module's attention/decoder layers.
        self.qk_norm = qk_norm
        self.layer_module = layer_module
        self.freeze_und = freeze_und
class NaiveCache:
    """
    Minimal per-layer key/value cache.

    `key_cache` and `value_cache` are dicts keyed by layer index (0..num_layers-1);
    every slot starts as `None` and is filled by the attention layers.
    """

    def __init__(self, num_layers):
        self.key_cache = {k: None for k in range(num_layers)}
        self.value_cache = {k: None for k in range(num_layers)}

    @property
    def num_layers(self):
        """Number of layer slots in the cache."""
        return len(self.key_cache)

    @property
    def seq_lens(self):
        """First-dimension size of layer 0's cached keys, or 0 when nothing is cached."""
        # .get() keeps this safe for a cache built with zero layers, where
        # key_cache has no entry for layer 0.
        first_layer_keys = self.key_cache.get(0)
        if first_layer_keys is None:
            return 0
        return first_layer_keys.shape[0]
@dataclass
class BaseNavitOutputWithPast(ModelOutput):
    """Model output bundling the packed query sequence with the key/value cache."""

    # Packed hidden states for the query tokens.
    packed_query_sequence: torch.FloatTensor = None
    # Per-layer key/value cache, when one is carried along.
    past_key_values: Optional[NaiveCache] = None
def pad_sequence(tensor, pad_size):
    """
    Append `pad_size` zero entries along dim 1 of a 3-D tensor.

    The padding is allocated with `tensor.new_zeros`, so it shares the input's
    dtype and device. Returns a new tensor of shape (d0, d1 + pad_size, d2).
    """
    dim0, _, dim2 = tensor.shape
    zeros = tensor.new_zeros((dim0, pad_size, dim2))
    return torch.cat([tensor, zeros], dim=1)
class PackedAttention(Qwen2Attention):
    """
    Qwen2 attention over packed (concatenated, variable-length) sequences.

    Training (`forward_train`) runs either per-sample SDPA, when `attention_mask` is a
    list of per-sample masks, or `flex_attention` with a block mask otherwise.
    Inference (`forward_inference`) uses `flash_attn_varlen_func` and can merge the new
    keys/values with a `NaiveCache`.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Optional per-head RMSNorm on queries and keys, controlled by config.qk_norm.
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def forward(self, *args, **kwargs):
        """Dispatch to `forward_train` in training mode, else `forward_inference`."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask: List[torch.Tensor],
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ):
        """
        Attention over a packed training sequence.

        Args:
            packed_sequence: packed hidden states; projected and viewed as
                (tokens, heads, head_dim).
            sample_lens: per-sample lengths used to split the packed sequence.
            attention_mask: a list with one mask per sample (SDPA path), or a block
                mask passed to `flex_attention` (any non-list value).
            packed_position_embeddings: `(cos, sin)` pair for rotary embeddings.

        Returns:
            The output-projected attention result, one row per packed token.
        """
        packed_query_states = self.q_proj(packed_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states,
            packed_key_states,
            packed_cos,
            packed_sin,
            unsqueeze_dim=1,
        )

        if isinstance(attention_mask, list):
            # SDPA has no GQA handling here: expand K/V heads to match the query heads.
            packed_key_states = packed_key_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states = packed_key_states.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            # (tokens, heads, dim) -> (heads, tokens, dim), then one chunk per sample.
            unpacked_query_states = packed_query_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)

            unpacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states,
                unpacked_key_states,
                unpacked_value_states,
                attention_mask,
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                unpacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(unpacked_attn_output, dim=1)
        else:
            # The module-level flex_attention import is commented out in this file, so
            # the name must be imported here; otherwise this branch raises NameError.
            # Importing lazily keeps the module importable where flex_attention is missing.
            from torch.nn.attention.flex_attention import flex_attention

            # Pad up to sum(sample_lens) tokens before calling flex_attention.
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states = pad_sequence(packed_query_states.permute(1, 0, 2), pad_size)
            packed_key_states = pad_sequence(packed_key_states.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states.unsqueeze(0),
                packed_key_states.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            # Drop the batch dim and the padded tail again.
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        return packed_attn_output

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ):
        """
        Variable-length flash attention for inference, with optional KV cache.

        Args:
            packed_query_sequence: packed hidden states of the new (query) tokens.
            query_lens: per-sample number of query tokens.
            packed_query_position_embeddings: `(cos, sin)` pair for rotary embeddings.
            packed_query_indexes: rows of the merged K/V buffer that receive the new tokens.
            past_key_values: cache read for this layer's previous keys/values. It is
                dereferenced when `update_past_key_values` is true, so it must not be
                `None` in that case.
            key_values_lens: per-sample number of cached tokens (used with a filled cache).
            packed_key_value_indexes: rows of the merged K/V buffer that receive the
                cached tokens.
            update_past_key_values: store the merged keys/values back into the cache.
            is_causal: forwarded to `flash_attn_varlen_func` as `causal`.

        Returns:
            `(packed_attn_output, past_key_values)`.
        """
        packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
        packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = self.v_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)

        packed_query_states = self.q_norm(packed_query_states)
        packed_key_states = self.k_norm(packed_key_states)

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states,
            packed_key_states,
            packed_cos,
            packed_sin,
            unsqueeze_dim=1,
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            # Scatter cached and new tokens into one buffer at their packed positions.
            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            # Allocate from the value cache so the buffer follows its dtype/device.
            merged_value_states = past_value_states.new_zeros((seqlens, self.num_key_value_heads, self.head_dim))
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        # Cumulative sequence boundaries with a leading 0.
        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=max(query_lens).item(),
            max_seqlen_k=max(key_values_lens).item(),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        packed_attn_output = self.o_proj(packed_attn_output)

        if update_past_key_values:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
class PackedAttentionMoT(Qwen2Attention):
    """
    Packed-sequence attention with two parameter sets.

    The inherited `q_proj`/`k_proj`/`v_proj`/`o_proj` and `q_norm`/`k_norm` handle one
    token group (understanding tokens in training; text tokens, or every token in
    "und" mode, at inference). The `*_moe_gen` counterparts handle the other group
    (generation tokens in training; VAE tokens in "gen" mode at inference). Tokens are
    routed by index tensors supplied by the caller.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)
        # Optional per-head RMSNorm on queries and keys, one pair per token group.
        if self.config.qk_norm:
            self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
            self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
            self.q_norm_moe_gen = nn.Identity()
            self.k_norm_moe_gen = nn.Identity()

        # Second set of projections for the generation-side tokens.
        self.q_proj_moe_gen = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj_moe_gen = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj_moe_gen = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(self, *args, **kwargs):
        """Dispatch to `forward_train` in training mode, else `forward_inference`."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ):
        """
        Attention over a packed training sequence with per-group projections.

        Args:
            packed_sequence: packed hidden states, one row per token.
            sample_lens: per-sample lengths used to split the packed sequence.
            attention_mask: a list with one mask per sample (SDPA path), or a block
                mask passed to `flex_attention` (any non-list value).
            packed_position_embeddings: `(cos, sin)` pair for rotary embeddings.
            packed_und_token_indexes: rows routed through the base projections/norms.
            packed_gen_token_indexes: rows routed through the `*_moe_gen` modules.

        Returns:
            The output-projected attention result, one row per packed token.
        """
        packed_query_states = packed_sequence.new_zeros((packed_sequence.shape[0], self.num_heads * self.head_dim))
        packed_key_states = packed_sequence.new_zeros(
            (packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
        )
        packed_value_states = packed_sequence.new_zeros(
            (packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
        )

        packed_sequence_und = packed_sequence[packed_und_token_indexes]
        packed_sequence_gen = packed_sequence[packed_gen_token_indexes]

        # Project each token group with its own weights and scatter back in place.
        packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
        packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(packed_sequence_gen)

        packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
        packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(packed_sequence_gen)

        packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
        packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(packed_sequence_gen)

        packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
        packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
        packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)
        if self.config.freeze_und:
            # Cut gradients through the understanding tokens' values.
            packed_value_states[packed_und_token_indexes] = packed_value_states[packed_und_token_indexes].detach()

        packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
        packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)

        packed_query_states_[packed_und_token_indexes] = self.q_norm(packed_query_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_query_states_[packed_und_token_indexes] = packed_query_states_[packed_und_token_indexes].detach()
        packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(
            packed_query_states[packed_gen_token_indexes]
        )

        packed_key_states_[packed_und_token_indexes] = self.k_norm(packed_key_states[packed_und_token_indexes])
        if self.config.freeze_und:
            packed_key_states_[packed_und_token_indexes] = packed_key_states_[packed_und_token_indexes].detach()
        packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(
            packed_key_states[packed_gen_token_indexes]
        )

        packed_cos, packed_sin = packed_position_embeddings
        packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
            packed_query_states_,
            packed_key_states_,
            packed_cos,
            packed_sin,
            unsqueeze_dim=1,
        )

        if isinstance(attention_mask, list):
            # SDPA has no GQA handling here: expand K/V heads to match the query heads.
            packed_key_states_ = packed_key_states_[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_key_states_ = packed_key_states_.reshape(-1, self.num_heads, self.head_dim)
            packed_value_states = packed_value_states[:, :, None, :].repeat(1, 1, self.num_key_value_groups, 1)
            packed_value_states = packed_value_states.reshape(-1, self.num_heads, self.head_dim)

            # (tokens, heads, dim) -> (heads, tokens, dim), then one chunk per sample.
            unpacked_query_states = packed_query_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_key_states = packed_key_states_.transpose(0, 1).split(sample_lens, dim=1)
            unpacked_value_states = packed_value_states.transpose(0, 1).split(sample_lens, dim=1)

            unpacked_attn_output = []
            for query_states, key_states, value_states, attention_mask_per_sample in zip(
                unpacked_query_states,
                unpacked_key_states,
                unpacked_value_states,
                attention_mask,
            ):
                with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
                    attn_output = scaled_dot_product_attention(
                        query_states.to(torch.bfloat16).unsqueeze(0),
                        key_states.to(torch.bfloat16).unsqueeze(0),
                        value_states.to(torch.bfloat16).unsqueeze(0),
                        attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
                    )
                unpacked_attn_output.append(attn_output.squeeze(0))
            packed_attn_output = torch.cat(unpacked_attn_output, dim=1)
        else:
            # The module-level flex_attention import is commented out in this file, so
            # the name must be imported here; otherwise this branch raises NameError.
            # Importing lazily keeps the module importable where flex_attention is missing.
            from torch.nn.attention.flex_attention import flex_attention

            # Pad up to sum(sample_lens) tokens before calling flex_attention.
            pad_size = sum(sample_lens) - packed_query_states.shape[0]
            packed_query_states_ = pad_sequence(packed_query_states_.permute(1, 0, 2), pad_size)
            packed_key_states_ = pad_sequence(packed_key_states_.permute(1, 0, 2), pad_size)
            packed_value_states = pad_sequence(packed_value_states.permute(1, 0, 2), pad_size)
            packed_attn_output = flex_attention(
                packed_query_states_.unsqueeze(0),  # 1, num_head, L, head_dim
                packed_key_states_.unsqueeze(0),
                packed_value_states.unsqueeze(0),
                enable_gqa=True,
                block_mask=attention_mask,
            )
            # Drop the batch dim and the padded tail again.
            end_index = packed_attn_output.shape[2] - pad_size
            packed_attn_output = packed_attn_output[0, :, :end_index, :]

        packed_attn_output = packed_attn_output.transpose(0, 1).reshape(-1, self.num_heads * self.head_dim)
        # Output projection is again applied per token group.
        packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
        packed_attn_output_[packed_und_token_indexes] = self.o_proj(packed_attn_output[packed_und_token_indexes])
        packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(
            packed_attn_output[packed_gen_token_indexes]
        )

        return packed_attn_output_

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ):
        """
        Variable-length flash attention for inference, with optional KV cache.

        Args:
            packed_query_sequence: packed hidden states of the new (query) tokens.
            query_lens: per-sample number of query tokens.
            packed_query_position_embeddings: `(cos, sin)` pair for rotary embeddings.
            packed_query_indexes: rows of the merged K/V buffer that receive the new tokens.
            past_key_values: cache read for this layer's previous keys/values. It is
                dereferenced when `update_past_key_values` is true, so it must not be
                `None` in that case.
            key_values_lens: per-sample number of cached tokens (used with a filled cache).
            packed_key_value_indexes: rows of the merged K/V buffer that receive the
                cached tokens.
            update_past_key_values: store the merged keys/values back into the cache.
            is_causal: forwarded to `flash_attn_varlen_func` as `causal`.
            mode: `"und"` routes every token through the base modules; `"gen"` routes
                rows in `packed_text_indexes` through the base modules and rows in
                `packed_vae_token_indexes` through the `*_moe_gen` modules.
            packed_vae_token_indexes: row indexes of VAE tokens (used in "gen" mode).
            packed_text_indexes: row indexes of text tokens (used in "gen" mode).

        Returns:
            `(packed_attn_output, past_key_values)`.

        Raises:
            ValueError: if `mode` is neither `"und"` nor `"gen"`.
        """
        if mode == "und":
            packed_query_states = self.q_proj(packed_query_sequence).view(-1, self.num_heads, self.head_dim)
            packed_key_states = self.k_proj(packed_query_sequence).view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = self.v_proj(packed_query_sequence).view(
                -1, self.num_key_value_heads, self.head_dim
            )
            packed_query_states = self.q_norm(packed_query_states)
            packed_key_states = self.k_norm(packed_key_states)
        elif mode == "gen":
            packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
            packed_query_states = packed_query_sequence.new_zeros(
                (packed_query_sequence.shape[0], self.num_heads * self.head_dim)
            )
            packed_key_states = packed_query_sequence.new_zeros(
                (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)
            )
            packed_value_states = packed_query_sequence.new_zeros(
                (packed_query_sequence.shape[0], self.num_key_value_heads * self.head_dim)
            )

            packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
            packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]

            # Project each token group with its own weights and scatter back in place.
            packed_query_states[packed_text_indexes] = self.q_proj(packed_text_query_sequence)
            packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(packed_vae_query_sequence)

            packed_key_states[packed_text_indexes] = self.k_proj(packed_text_query_sequence)
            packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(packed_vae_query_sequence)

            packed_value_states[packed_text_indexes] = self.v_proj(packed_text_query_sequence)
            packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(packed_vae_query_sequence)

            packed_query_states = packed_query_states.view(-1, self.num_heads, self.head_dim)
            packed_key_states = packed_key_states.view(-1, self.num_key_value_heads, self.head_dim)
            packed_value_states = packed_value_states.view(-1, self.num_key_value_heads, self.head_dim)

            # The q/k norms are applied on float32 copies of the states.
            packed_query_states = packed_query_states.to(torch.float32)
            packed_query_states[packed_text_indexes] = self.q_norm(packed_query_states[packed_text_indexes])
            packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(
                packed_query_states[packed_vae_token_indexes]
            )

            packed_key_states = packed_key_states.to(torch.float32)
            packed_key_states[packed_text_indexes] = self.k_norm(packed_key_states[packed_text_indexes])
            packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(
                packed_key_states[packed_vae_token_indexes]
            )
        else:
            # Previously an unknown mode fell through to an UnboundLocalError below.
            raise ValueError(f"mode must be 'und' or 'gen', got {mode!r}")

        packed_cos, packed_sin = packed_query_position_embeddings
        packed_query_states, packed_key_states = apply_rotary_pos_emb(
            packed_query_states,
            packed_key_states,
            packed_cos,
            packed_sin,
            unsqueeze_dim=1,
        )

        packed_query_states = packed_query_states.to(torch.bfloat16)
        packed_key_states = packed_key_states.to(torch.bfloat16)
        packed_value_states = packed_value_states.to(torch.bfloat16)

        if past_key_values is not None and past_key_values.key_cache[self.layer_idx] is not None:
            past_key_states = past_key_values.key_cache[self.layer_idx]
            past_value_states = past_key_values.value_cache[self.layer_idx]

            # Scatter cached and new tokens into one buffer at their packed positions.
            seqlens = sum(query_lens) + sum(key_values_lens)
            merged_key_states = past_key_states.new_zeros(size=[seqlens, self.num_key_value_heads, self.head_dim])
            # Allocate from the value cache so the buffer follows its dtype/device.
            merged_value_states = past_value_states.new_zeros(
                size=[seqlens, self.num_key_value_heads, self.head_dim]
            )
            merged_key_states[packed_query_indexes] = packed_key_states
            merged_key_states[packed_key_value_indexes] = past_key_states
            merged_value_states[packed_query_indexes] = packed_value_states
            merged_value_states[packed_key_value_indexes] = past_value_states
            key_values_lens = key_values_lens + query_lens
        else:
            merged_key_states = packed_key_states
            merged_value_states = packed_value_states
            key_values_lens = query_lens

        # Cumulative sequence boundaries with a leading 0.
        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))

        packed_attn_output = flash_attn_varlen_func(
            q=packed_query_states,
            k=merged_key_states,
            v=merged_value_states,
            cu_seqlens_q=cu_seqlens_q.to(torch.int32),
            cu_seqlens_k=cu_seqlens_k.to(torch.int32),
            max_seqlen_q=max(query_lens).item(),
            max_seqlen_k=max(key_values_lens).item(),
            causal=is_causal,
        )
        packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
        if mode == "und":
            packed_attn_output = self.o_proj(packed_attn_output)
        elif mode == "gen":
            # In-place per-group output projection.
            packed_attn_output[packed_text_indexes] = self.o_proj(packed_attn_output[packed_text_indexes])
            packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(
                packed_attn_output[packed_vae_token_indexes]
            )

        if update_past_key_values:
            past_key_values.key_cache[self.layer_idx] = merged_key_states
            past_key_values.value_cache[self.layer_idx] = merged_value_states

        return packed_attn_output, past_key_values
class Qwen2DecoderLayer(nn.Module):
    """Pre-norm decoder layer: packed self-attention followed by an MLP.

    Both sub-blocks are applied residually as ``x + sublayer(norm(x))``.
    ``forward`` picks the training or inference path from ``self.training``.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Creation order determines parameter registration order; keep it stable.
        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``forward_train`` or ``forward_inference``."""
        impl = self.forward_train if self.training else self.forward_inference
        return impl(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        """Run the layer on a packed training sequence.

        Args:
            packed_sequence: Hidden states of the packed samples.
            sample_lens: Per-sample lengths, forwarded to the attention module.
            attention_mask: Mask forwarded unchanged to the attention module.
            packed_position_embeddings: ``(cos, sin)`` pair forwarded to the
                attention module.

        Returns:
            The hidden states after the attention and MLP residual blocks.
        """
        # Attention sub-block.
        skip = packed_sequence
        attn_out = self.self_attn(
            packed_sequence=self.input_layernorm(packed_sequence),
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        packed_sequence = skip + attn_out

        # Feed-forward sub-block.
        skip = packed_sequence
        mlp_out = self.mlp(self.post_attention_layernorm(packed_sequence))
        return skip + mlp_out

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
    ) -> BaseNavitOutputWithPast:
        """Run the layer at inference time with an optional KV cache.

        All cache- and packing-related arguments are forwarded unchanged to
        the attention module.

        Returns:
            A ``(hidden_states, past_key_values)`` tuple. Note that the code
            returns a plain tuple even though the annotation names
            ``BaseNavitOutputWithPast``.
        """
        # Attention sub-block.
        skip = packed_query_sequence
        attn_out, past_key_values = self.self_attn(
            packed_query_sequence=self.input_layernorm(packed_query_sequence),
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        packed_query_sequence = skip + attn_out

        # Feed-forward sub-block.
        skip = packed_query_sequence
        mlp_out = self.mlp(self.post_attention_layernorm(packed_query_sequence))
        return skip + mlp_out, past_key_values
class Qwen2MoTDecoderLayer(nn.Module):
    """Decoder layer with two sets of norm/MLP weights routed by token index.

    The plain modules (``input_layernorm``, ``post_attention_layernorm``,
    ``mlp``) are applied to the "und"/text token indexes, and the
    ``*_moe_gen`` counterparts to the "gen"/VAE token indexes. Attention is
    delegated to ``attn_module``.
    """

    def __init__(
        self,
        config,
        layer_idx: Optional[int] = None,
        attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.freeze_und = config.freeze_und

        # Creation order determines parameter registration order; keep it stable.
        self.self_attn = attn_module(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)

        width, norm_eps = config.hidden_size, config.rms_norm_eps
        self.input_layernorm = Qwen2RMSNorm(width, eps=norm_eps)
        self.input_layernorm_moe_gen = Qwen2RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(width, eps=norm_eps)

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``forward_train`` or ``forward_inference``."""
        impl = self.forward_train if self.training else self.forward_inference
        return impl(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        """Run the layer on a packed training sequence.

        Positions listed in ``packed_und_token_indexes`` go through the plain
        norm/MLP weights and those in ``packed_gen_token_indexes`` through the
        ``*_moe_gen`` weights. Positions in neither index set stay zero in the
        sub-layer outputs before the residual add. When ``self.freeze_und`` is
        set, the sub-layer outputs at the "und" positions are detached.

        Returns:
            The hidden states after the attention and MLP residual blocks.
        """
        und_idx = packed_und_token_indexes
        gen_idx = packed_gen_token_indexes

        # Attention sub-block.
        skip = packed_sequence
        normed = packed_sequence.new_zeros(packed_sequence.shape)
        normed[und_idx] = self.input_layernorm(packed_sequence[und_idx])
        normed[gen_idx] = self.input_layernorm_moe_gen(packed_sequence[gen_idx])
        attn_out = self.self_attn(
            packed_sequence=normed,
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
            packed_und_token_indexes=und_idx,
            packed_gen_token_indexes=gen_idx,
        )
        if self.freeze_und:
            attn_out[und_idx] = attn_out[und_idx].detach()
        packed_sequence = skip + attn_out

        # Feed-forward sub-block.
        skip = packed_sequence
        mlp_out = packed_sequence.new_zeros(packed_sequence.shape)
        und_hidden = self.post_attention_layernorm(packed_sequence[und_idx])
        mlp_out[und_idx] = self.mlp(und_hidden)
        if self.freeze_und:
            mlp_out[und_idx] = mlp_out[und_idx].detach()
        gen_hidden = self.post_attention_layernorm_moe_gen(packed_sequence[gen_idx])
        mlp_out[gen_idx] = self.mlp_moe_gen(gen_hidden)
        return skip + mlp_out

    def _norm_attention_input(
        self, sequence, mode, packed_text_indexes, packed_vae_token_indexes
    ):
        """Apply the pre-attention norm(s) for ``mode``; other modes pass through."""
        if mode == "und":
            return self.input_layernorm(sequence)
        if mode == "gen":
            normed = torch.zeros_like(sequence)
            normed[packed_text_indexes] = self.input_layernorm(
                sequence[packed_text_indexes]
            )
            normed[packed_vae_token_indexes] = self.input_layernorm_moe_gen(
                sequence[packed_vae_token_indexes]
            )
            return normed
        return sequence

    def _mlp_by_mode(
        self, sequence, mode, packed_text_indexes, packed_vae_token_indexes
    ):
        """Apply post-attention norm + MLP for ``mode``; other modes pass through."""
        if mode == "und":
            return self.mlp(self.post_attention_layernorm(sequence))
        if mode == "gen":
            text_tokens = sequence[packed_text_indexes]
            vae_tokens = sequence[packed_vae_token_indexes]
            text_tokens = self.post_attention_layernorm(text_tokens).to(torch.bfloat16)
            vae_tokens = self.post_attention_layernorm_moe_gen(vae_tokens).to(
                torch.bfloat16
            )
            routed = torch.zeros_like(sequence).to(torch.bfloat16)
            routed[packed_text_indexes] = self.mlp(text_tokens)
            routed[packed_vae_token_indexes] = self.mlp_moe_gen(vae_tokens)
            return routed
        return sequence

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Run the layer at inference time, optionally under TaylorSeer caching.

        If the instance has a truthy ``enable_taylorseer`` attribute, the step
        kind is read from ``self.current["type"]``:

        * ``"full"``: the layer is computed normally and its output is passed
          to ``derivative_approximation``.
        * ``"Taylor"``: the layer computation is skipped and the output comes
          from ``taylor_formula``; ``past_key_values`` is returned as received.

        In ``mode == "gen"`` the norm/MLP weights are selected per token via
        ``packed_text_indexes`` and ``packed_vae_token_indexes``, and the MLP
        branch output is cast to bfloat16.

        Returns:
            A ``(hidden_states, past_key_values)`` tuple. Note that the code
            returns a plain tuple even though the annotation names
            ``BaseNavitOutputWithPast``.
        """
        enable_taylorseer = getattr(self, "enable_taylorseer", False)
        if enable_taylorseer and self.current["type"] == "full":
            self.current["module"] = "total"
            taylor_cache_init(cache_dic=self.cache_dic, current=self.current)

        # The step type is re-read here on purpose: it is looked up again
        # after the cache initialisation above has run.
        if not enable_taylorseer or self.current["type"] == "full":
            # Attention sub-block.
            skip = packed_query_sequence
            attn_in = self._norm_attention_input(
                packed_query_sequence,
                mode,
                packed_text_indexes,
                packed_vae_token_indexes,
            )
            attn_out, past_key_values = self.self_attn(
                packed_query_sequence=attn_in,
                query_lens=query_lens,
                packed_query_position_embeddings=packed_query_position_embeddings,
                packed_query_indexes=packed_query_indexes,
                past_key_values=past_key_values,
                key_values_lens=key_values_lens,
                packed_key_value_indexes=packed_key_value_indexes,
                update_past_key_values=update_past_key_values,
                is_causal=is_causal,
                mode=mode,
                packed_vae_token_indexes=packed_vae_token_indexes,
                packed_text_indexes=packed_text_indexes,
            )
            packed_query_sequence = skip + attn_out

            # Feed-forward sub-block.
            skip = packed_query_sequence
            mlp_out = self._mlp_by_mode(
                packed_query_sequence,
                mode,
                packed_text_indexes,
                packed_vae_token_indexes,
            )
            packed_query_sequence = skip + mlp_out

        if enable_taylorseer:
            step_type = self.current["type"]
            if step_type == "full":
                derivative_approximation(
                    cache_dic=self.cache_dic,
                    current=self.current,
                    feature=packed_query_sequence,
                )
            elif step_type == "Taylor":
                self.current["module"] = "total"
                packed_query_sequence = taylor_formula(
                    cache_dic=self.cache_dic, current=self.current
                )

        return packed_query_sequence, past_key_values
class Qwen2MoEDecoderLayer(nn.Module):
    """Decoder layer with shared attention/norms and two token-routed MLPs.

    Attention and both RMS norms are shared by all tokens; only the
    feed-forward step chooses between ``mlp`` and ``mlp_moe_gen`` per token.
    """

    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Creation order determines parameter registration order; keep it stable.
        self.self_attn = PackedAttention(config, layer_idx)
        self.mlp = Qwen2MLP(config)
        self.mlp_moe_gen = Qwen2MLP(config)
        norm_eps = config.rms_norm_eps
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=norm_eps)

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``forward_train`` or ``forward_inference``."""
        impl = self.forward_train if self.training else self.forward_inference
        return impl(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        packed_und_token_indexes: torch.LongTensor,
        packed_gen_token_indexes: torch.LongTensor,
    ) -> torch.Tensor:
        """Run the layer on a packed training sequence.

        The feed-forward output is assembled in a zero tensor: positions in
        ``packed_und_token_indexes`` come from ``mlp`` and positions in
        ``packed_gen_token_indexes`` from ``mlp_moe_gen``.

        Returns:
            The hidden states after the attention and MLP residual blocks.
        """
        # Attention sub-block.
        skip = packed_sequence
        attn_out = self.self_attn(
            packed_sequence=self.input_layernorm(packed_sequence),
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=packed_position_embeddings,
        )
        packed_sequence = skip + attn_out

        # Feed-forward sub-block, routed per token.
        skip = packed_sequence
        normed = self.post_attention_layernorm(packed_sequence)
        routed = normed.new_zeros(normed.shape)
        und_out = self.mlp(normed[packed_und_token_indexes])
        gen_out = self.mlp_moe_gen(normed[packed_gen_token_indexes])
        routed[packed_und_token_indexes] = und_out
        routed[packed_gen_token_indexes] = gen_out
        return skip + routed

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_embeddings: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Run the layer at inference time with an optional KV cache.

        ``mode``, ``packed_vae_token_indexes`` and ``packed_text_indexes`` only
        affect the feed-forward step; they are not passed to attention. In
        ``"gen"`` mode the routed MLP output buffer is bfloat16.

        Returns:
            A ``(hidden_states, past_key_values)`` tuple. Note that the code
            returns a plain tuple even though the annotation names
            ``BaseNavitOutputWithPast``.
        """
        # Attention sub-block.
        skip = packed_query_sequence
        attn_out, past_key_values = self.self_attn(
            packed_query_sequence=self.input_layernorm(packed_query_sequence),
            query_lens=query_lens,
            packed_query_position_embeddings=packed_query_position_embeddings,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        packed_query_sequence = skip + attn_out

        # Feed-forward sub-block.
        skip = packed_query_sequence
        normed = self.post_attention_layernorm(packed_query_sequence)
        if mode == "und":
            mlp_out = self.mlp(normed)
        elif mode == "gen":
            mlp_out = torch.zeros_like(normed).to(torch.bfloat16)
            mlp_out[packed_text_indexes] = self.mlp(normed[packed_text_indexes])
            mlp_out[packed_vae_token_indexes] = self.mlp_moe_gen(
                normed[packed_vae_token_indexes]
            )
        else:
            # Any other mode: the normalised states are added back as-is.
            mlp_out = normed
        return skip + mlp_out, past_key_values
# Registry of decoder-layer constructors, keyed by the string stored in
# ``config.layer_module``; ``Qwen2Model.__init__`` looks the constructor up
# here and calls it as ``layer_module(config, layer_idx)``.
# The MoT entry is a ``partial`` that pins ``attn_module`` explicitly.
Decoder_layer_dict = {
    "Qwen2DecoderLayer": Qwen2DecoderLayer,
    "Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
    "Qwen2MoTDecoderLayer": partial(
        Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT
    ),
}
class Qwen2Model(Qwen2PreTrainedModel):
    """Stack of packed-sequence decoder layers with a final RMS norm.

    The layer class is selected by ``config.layer_module`` through
    ``Decoder_layer_dict``. When that name contains ``"Mo"`` the model runs in
    MoE/MoT mode (``self.use_moe``) and owns a second final norm,
    ``norm_moe_gen``, for generation tokens.
    """

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.use_moe = "Mo" in config.layer_module

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        build_layer = Decoder_layer_dict[config.layer_module]
        self.layers = nn.ModuleList(
            [build_layer(config, idx) for idx in range(config.num_hidden_layers)]
        )

        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        if self.use_moe:
            self.norm_moe_gen = Qwen2RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``forward_train`` or ``forward_inference``."""
        impl = self.forward_train if self.training else self.forward_inference
        return impl(*args, **kwargs)

    def _rotary_tables(self, sequence, position_ids):
        """Return the ``(cos, sin)`` pair shared by every decoder layer.

        The position ids get a leading dimension for ``rotary_emb`` and that
        dimension is squeezed away again from both outputs.
        """
        cos, sin = self.rotary_emb(sequence, position_ids.unsqueeze(0))
        return (cos.squeeze(0), sin.squeeze(0))

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run all decoder layers on a packed training sequence.

        In MoE mode ``packed_und_token_indexes`` is required; a missing
        ``packed_gen_token_indexes`` is replaced by an empty index tensor. The
        final norm is then applied per token group, and positions in neither
        group stay zero in the returned tensor. When
        ``self.config.freeze_und`` is set, the "und" positions are detached at
        the input (written back in place into ``packed_sequence``) and, in MoE
        mode, again after the final norm.

        Returns:
            The normalised hidden states.
        """
        und_idx = packed_und_token_indexes
        if self.config.freeze_und:
            packed_sequence[und_idx] = packed_sequence[und_idx].detach()

        # create position embeddings to be shared across the decoder layers
        layer_kwargs = dict(
            sample_lens=sample_lens,
            attention_mask=attention_mask,
            packed_position_embeddings=self._rotary_tables(
                packed_sequence, packed_position_ids
            ),
        )
        if self.use_moe:
            assert und_idx is not None
            if packed_gen_token_indexes is None:
                packed_gen_token_indexes = und_idx.new_ones(size=[0])
            layer_kwargs.update(
                packed_und_token_indexes=und_idx,
                packed_gen_token_indexes=packed_gen_token_indexes,
            )

        for decoder_layer in self.layers:
            packed_sequence = decoder_layer(
                packed_sequence=packed_sequence, **layer_kwargs
            )

        if not self.use_moe:
            return self.norm(packed_sequence)

        gen_idx = packed_gen_token_indexes
        normed = torch.zeros_like(packed_sequence)
        normed[und_idx] = self.norm(packed_sequence[und_idx])
        if self.config.freeze_und:
            normed[und_idx] = normed[und_idx].detach()
        normed[gen_idx] = self.norm_moe_gen(packed_sequence[gen_idx])
        return normed

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Run all decoder layers at inference time.

        ``mode`` and the text/VAE index tensors are only forwarded to the
        layers in MoE mode; in ``"gen"`` mode both index tensors are required.
        If the instance has a truthy ``enable_taylorseer`` attribute,
        ``cal_type`` is called first, every layer receives the shared
        ``current`` / ``cache_dic`` state, and ``self.current["step"]`` is
        incremented once at the end.

        Returns:
            ``BaseNavitOutputWithPast`` carrying the normalised hidden states
            and the (possibly updated) ``past_key_values``.
        """
        enable_taylorseer = getattr(self, "enable_taylorseer", False)
        if enable_taylorseer:
            cal_type(self.cache_dic, self.current)
            self.current["stream"] = "layers_stream"

        # create position embeddings to be shared across the decoder layers
        layer_kwargs = dict(
            query_lens=query_lens,
            packed_query_position_embeddings=self._rotary_tables(
                packed_query_sequence, packed_query_position_ids
            ),
            packed_query_indexes=packed_query_indexes,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
        )
        if self.use_moe:
            layer_kwargs.update(mode=mode)
            if mode == "gen":
                assert packed_vae_token_indexes is not None
                assert packed_text_indexes is not None
                layer_kwargs.update(
                    packed_vae_token_indexes=packed_vae_token_indexes,
                    packed_text_indexes=packed_text_indexes,
                )

        for layer_idx, decoder_layer in enumerate(self.layers):
            if enable_taylorseer:
                decoder_layer.current = self.current
                decoder_layer.cache_dic = self.cache_dic
                decoder_layer.enable_taylorseer = True
                self.current["layer"] = layer_idx
            packed_query_sequence, past_key_values = decoder_layer(
                packed_query_sequence=packed_query_sequence,
                past_key_values=past_key_values,
                **layer_kwargs,
            )

        if not self.use_moe or mode == "und":
            packed_query_sequence = self.norm(packed_query_sequence)
        elif mode == "gen":
            normed = torch.zeros_like(packed_query_sequence)
            normed[packed_text_indexes] = self.norm(
                packed_query_sequence[packed_text_indexes]
            )
            normed[packed_vae_token_indexes] = self.norm_moe_gen(
                packed_query_sequence[packed_vae_token_indexes]
            )
            packed_query_sequence = normed

        if enable_taylorseer:
            self.current["step"] += 1

        return BaseNavitOutputWithPast(
            packed_query_sequence=packed_query_sequence,
            past_key_values=past_key_values,
        )
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
    """``Qwen2Model`` plus an ``lm_head`` projection to the vocabulary.

    Both forward paths return the decoder output directly; ``lm_head`` is not
    applied inside this class and is exposed through the embedding accessors.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def init_moe(self):
        """Copy every base weight into its ``*_moe_gen`` counterpart.

        For each parameter whose name contains ``"moe_gen"``, the source is the
        state-dict entry obtained by removing ``"_moe_gen"`` from the name.

        Raises:
            KeyError: If a ``moe_gen`` parameter has no matching base entry.
        """
        # Snapshot the state dict once: calling ``self.state_dict()`` inside
        # the loop rebuilds the whole mapping for every ``moe_gen`` parameter.
        # The entries share storage with the live parameters, so the values
        # read here are the same ones a fresh call would return.
        state_dict = self.state_dict()
        for name, param in self.named_parameters():
            if "moe_gen" not in name:
                continue
            original_name = name.replace("_moe_gen", "")
            param.data.copy_(state_dict[original_name].data)

    def get_input_embeddings(self):
        """Return the token embedding module of the decoder."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the decoder."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the ``lm_head`` projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the ``lm_head`` projection."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder model."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder model."""
        return self.model

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``forward_train`` or ``forward_inference``."""
        if self.training:
            return self.forward_train(*args, **kwargs)
        else:
            return self.forward_inference(*args, **kwargs)

    def forward_train(
        self,
        packed_sequence: torch.Tensor,
        sample_lens: List[int],
        attention_mask,
        packed_position_ids: torch.Tensor,
        packed_und_token_indexes: Optional[torch.LongTensor] = None,
        packed_gen_token_indexes: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Delegate to the decoder with the same arguments and return its output."""
        outputs = self.model(
            packed_sequence=packed_sequence,
            sample_lens=sample_lens,
            packed_position_ids=packed_position_ids,
            attention_mask=attention_mask,
            packed_und_token_indexes=packed_und_token_indexes,
            packed_gen_token_indexes=packed_gen_token_indexes,
        )
        return outputs

    def forward_inference(
        self,
        packed_query_sequence: torch.Tensor,
        query_lens: torch.Tensor,
        packed_query_position_ids: torch.Tensor,
        packed_query_indexes: torch.Tensor,
        past_key_values: Optional[NaiveCache] = None,
        key_values_lens: Optional[torch.Tensor] = None,
        packed_key_value_indexes: Optional[torch.Tensor] = None,
        update_past_key_values=True,
        is_causal=True,
        mode="und",
        packed_vae_token_indexes=None,
        packed_text_indexes=None,
    ) -> BaseNavitOutputWithPast:
        """Delegate to the decoder with the same arguments and return its output."""
        outputs = self.model(
            packed_query_sequence=packed_query_sequence,
            query_lens=query_lens,
            packed_query_position_ids=packed_query_position_ids,
            packed_query_indexes=packed_query_indexes,
            past_key_values=past_key_values,
            key_values_lens=key_values_lens,
            packed_key_value_indexes=packed_key_value_indexes,
            update_past_key_values=update_past_key_values,
            is_causal=is_causal,
            mode=mode,
            packed_vae_token_indexes=packed_vae_token_indexes,
            packed_text_indexes=packed_text_indexes,
        )
        return outputs
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/siglip_navit.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2024 The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
torch
import
nn
from
transformers.activations
import
ACT2FN
from
..siglip.configuration_siglip
import
SiglipVisionConfig
as
_SiglipVisionConfig
from
..siglip.modeling_siglip
import
SiglipAttention
,
SiglipPreTrainedModel
class SiglipVisionConfig(_SiglipVisionConfig):
    r"""
    Configuration of the Siglip vision encoder, extended with a ``rope`` switch.

    All arguments except ``rope`` are forwarded unchanged to the parent Siglip
    vision configuration; ``rope`` is stored on the instance.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of channels in the input images.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 16):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the layer normalization layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        rope (`bool`, *optional*, defaults to `True`):
            Stored as `self.rope`. In this file, a true value makes `SiglipVisionEmbeddings` skip the learned
            position embedding and makes `SiglipFlashAttention2` apply 2D rotary embeddings to queries and keys.

    Example:

    ```python
    >>> from transformers import SiglipVisionConfig, SiglipVisionModel

    >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration
    >>> configuration = SiglipVisionConfig()

    >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration
    >>> model = SiglipVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "siglip_vision_model"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        image_size=224,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        rope=True,
        **kwargs,
    ):
        # Everything the parent configuration understands, gathered in one place.
        parent_kwargs = {
            "hidden_size": hidden_size,
            "intermediate_size": intermediate_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "num_channels": num_channels,
            "image_size": image_size,
            "patch_size": patch_size,
            "hidden_act": hidden_act,
            "layer_norm_eps": layer_norm_eps,
            "attention_dropout": attention_dropout,
        }
        super().__init__(**parent_kwargs, **kwargs)
        self.rope = rope
class RotaryEmbedding2D(torch.nn.Module):
    """Precomputed cos/sin rotary tables for a ``max_h`` x ``max_w`` grid.

    Four buffers are registered, in this order: ``cos_h``, ``sin_h``,
    ``cos_w``, ``sin_w``. Each has one row per grid cell (row-major,
    ``max_h * max_w`` rows); the ``*_h`` tables use the cell's row index as
    position and the ``*_w`` tables its column index.
    """

    def __init__(self, dim, max_h, max_w, base=10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.int64).float() / dim
        inv_freq = 1.0 / (base**exponents)

        # Row / column index of every cell, each as a (max_h, max_w) grid.
        row_ids = torch.arange(0, max_h).to(inv_freq.dtype)[:, None].repeat(1, max_w)
        col_ids = torch.arange(0, max_w).to(inv_freq.dtype)[None, :].repeat(max_h, 1)

        tables = {}
        tables["cos_h"], tables["sin_h"] = self._forward_one_side(row_ids, inv_freq)
        tables["cos_w"], tables["sin_w"] = self._forward_one_side(col_ids, inv_freq)
        for buffer_name, table in tables.items():
            self.register_buffer(buffer_name, table)

    def _forward_one_side(self, grid, inv_freq):
        """Return ``(cos, sin)`` tables for one axis of the grid.

        The angles are the outer product of ``grid`` and ``inv_freq``,
        duplicated along the last dimension, with the two grid dimensions
        flattened into one.
        """
        angles = grid.unsqueeze(-1) * inv_freq[None, None, :]
        doubled = torch.cat((angles, angles), dim=-1).flatten(0, 1)
        return torch.cos(doubled), torch.sin(doubled)
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    The last dimension is split at ``x.shape[-1] // 2``; the result is
    ``concat(-second_half, first_half)``.
    """
    split_at = x.shape[-1] // 2
    first_half, second_half = x[..., :split_at], x[..., split_at:]
    return torch.cat((-second_half, first_half), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the same rotary embedding to a query and a key tensor.

    ``cos`` and ``sin`` get an extra dimension at position 1 so they broadcast
    over the head dimension of ``q`` and ``k``.

    Returns:
        The rotated ``(q, k)`` pair.
    """
    # unsqueeze due to the head dimension
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding for packed image patches.

    A learned position embedding is created, and added in ``forward``, only
    when ``config.rope`` is false.
    """

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: the stride equals the kernel size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        if not config.rope:
            self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

    def convert_conv2d_to_linear(self, config, meta=False):
        """Replace the ``Conv2d`` patch embedding with a ``Linear`` layer.

        The convolution weight is permuted with ``(0, 2, 3, 1)`` and flattened
        to ``(embed_dim, num_channels * patch_size**2)``; the bias data is
        reused. The linear layer therefore presumably expects each patch
        flattened with channels last -- verify against the caller.

        Args:
            config: Configuration providing ``num_channels``.
            meta: When true, the new layer is created on the ``"meta"`` device.
        """
        in_features = config.num_channels * self.patch_size**2
        factory_kwargs = {"device": "meta"} if meta else {}
        replacement = nn.Linear(in_features, self.embed_dim, bias=True, **factory_kwargs)

        conv = self.patch_embedding
        flat_weight = conv.weight.permute(0, 2, 3, 1).reshape(self.embed_dim, in_features)
        replacement.weight.data = flat_weight
        replacement.bias.data = conv.bias.data

        # Drop the registered conv module before registering the linear one.
        del self.patch_embedding
        self.patch_embedding = replacement

    def forward(
        self,
        packed_pixel_values: torch.FloatTensor,
        packed_flattened_position_ids: torch.LongTensor,
    ) -> torch.Tensor:
        """Embed the packed patches, adding position embeddings unless rope is on."""
        patch_embeds = self.patch_embedding(packed_pixel_values)
        if self.config.rope:
            return patch_embeds
        return patch_embeds + self.position_embedding(packed_flattened_position_ids)
class SiglipFlashAttention2(SiglipAttention):
    """Siglip attention computed with ``flash_attn_varlen_func`` on packed tokens.

    Attention is non-causal. Queries, keys and values are cast to bfloat16
    before the kernel call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend over packed tokens.

        Args:
            hidden_states: 2-D tensor (it is unpacked as
                ``total_q_len, _ = hidden_states.size()``).
            cu_seqlens: Passed as both ``cu_seqlens_q`` and ``cu_seqlens_k``.
            max_seqlen: Passed as both ``max_seqlen_q`` and ``max_seqlen_k``.
            cos_h, sin_h: Rotary tables applied to the first half of each
                head's channels when ``self.config.rope`` is true.
            cos_w, sin_w: Rotary tables applied to the second half.
            **kwargs: Accepted and ignored.

        Returns:
            The output-projected attention result, one row per input token.
        """
        total_q_len, _ = hidden_states.size()
        per_head = (total_q_len, self.num_heads, self.head_dim)

        query_states = self.q_proj(hidden_states).view(*per_head)
        key_states = self.k_proj(hidden_states).view(*per_head)
        value_states = self.v_proj(hidden_states).view(*per_head)

        if self.config.rope:
            half = self.head_dim // 2
            # First half of the channels carries the height rotation, second
            # half the width rotation.
            q_h, q_w = query_states[:, :, :half], query_states[:, :, half:]
            k_h, k_w = key_states[:, :, :half], key_states[:, :, half:]
            q_h, k_h = apply_rotary_pos_emb(q_h, k_h, cos_h, sin_h)
            q_w, k_w = apply_rotary_pos_emb(q_w, k_w, cos_w, sin_w)
            query_states = torch.cat([q_h, q_w], dim=-1)
            key_states = torch.cat([k_h, k_w], dim=-1)

        attn_output = flash_attn_varlen_func(
            query_states.to(torch.bfloat16),
            key_states.to(torch.bfloat16),
            value_states.to(torch.bfloat16),
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=False,
        )
        return self.out_proj(attn_output.reshape(total_q_len, -1))
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: ``fc1`` -> activation -> ``fc2``."""

    def __init__(self, config):
        """Build the block from ``config``.

        Reads ``config.hidden_act`` (key into ``ACT2FN``),
        ``config.hidden_size`` and ``config.intermediate_size``.
        """
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``fc2(activation_fn(fc1(hidden_states)))``."""
        expanded = self.fc1(hidden_states)
        return self.fc2(self.activation_fn(expanded))
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual add."""

    def __init__(self, config: SiglipVisionConfig):
        """Create the attention, MLP and both layer norms from ``config``.

        Reads ``config.hidden_size`` and ``config.layer_norm_eps``; the rest of
        ``config`` is consumed by ``SiglipFlashAttention2`` and ``SiglipMLP``.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        eps = config.layer_norm_eps
        self.self_attn = SiglipFlashAttention2(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=eps)
        self.mlp = SiglipMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        """Apply one encoder layer.

        Args:
            hidden_states: Input token features.
            cu_seqlens: Passed through to the attention module.
            max_seqlen: Passed through to the attention module.
            cos_h, sin_h, cos_w, sin_w: Optional rotary tables, passed through
                to the attention module unchanged.

        Returns:
            The input plus the attention branch, plus the MLP branch computed
            on that intermediate sum.
        """
        # Attention branch: normalise, attend, add back onto the input.
        attended = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            cos_h=cos_h,
            sin_h=sin_h,
            cos_w=cos_w,
            sin_w=sin_w,
        )
        hidden_states = hidden_states + attended

        # Feed-forward branch: normalise, transform, add back.
        transformed = self.mlp(self.layer_norm2(hidden_states))
        return hidden_states + transformed
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``SiglipEncoderLayer`` modules."""

    def __init__(self, config: SiglipVisionConfig):
        """Instantiate every layer from the same ``config``."""
        super().__init__()
        self.config = config
        stacked = [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(stacked)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
        cos_h: torch.Tensor = None,
        sin_h: torch.Tensor = None,
        cos_w: torch.Tensor = None,
        sin_w: torch.Tensor = None,
    ) -> torch.Tensor:
        """Feed ``inputs_embeds`` through every layer in order.

        ``cu_seqlens``, ``max_seqlen`` and the four rotary tables are handed to
        each layer unchanged; only the hidden states evolve between layers.

        Returns:
            The output of the last layer (or ``inputs_embeds`` itself when
            there are no layers).
        """
        rotary = {"cos_h": cos_h, "sin_h": sin_h, "cos_w": cos_w, "sin_w": sin_w}
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(hidden_states, cu_seqlens, max_seqlen, **rotary)
        return hidden_states
class SiglipVisionTransformer(nn.Module):
    """Embeddings, optional 2-D rotary tables, encoder stack and final norm."""

    def __init__(self, config: SiglipVisionConfig):
        """Build the sub-modules from ``config``.

        When ``config.rope`` is truthy a ``RotaryEmbedding2D`` is created with
        half the per-head dimension and a square grid of
        ``config.image_size // config.patch_size`` positions per side.
        """
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)

        if config.rope:
            grid_side = config.image_size // config.patch_size
            head_dim = config.hidden_size // config.num_attention_heads
            self.rope = RotaryEmbedding2D(head_dim // 2, grid_side, grid_side)

        self.encoder = SiglipEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Encode packed patches.

        Args:
            packed_pixel_values: Forwarded to the embedding module.
            packed_flattened_position_ids: Forwarded to the embedding module
                and, when RoPE is enabled, used to index the four rotary
                tables stored on ``self.rope``.
            cu_seqlens: Forwarded to the encoder.
            max_seqlen: Forwarded to the encoder.

        Returns:
            The encoder output after ``post_layernorm``.
        """
        hidden_states = self.embeddings(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
        )

        rotary_inputs = {}
        if self.config.rope:
            # Gather the per-token rows of each precomputed table.
            position_ids = packed_flattened_position_ids
            rotary_inputs["cos_h"] = self.rope.cos_h[position_ids]
            rotary_inputs["sin_h"] = self.rope.sin_h[position_ids]
            rotary_inputs["cos_w"] = self.rope.cos_w[position_ids]
            rotary_inputs["sin_w"] = self.rope.sin_w[position_ids]

        encoded = self.encoder(
            inputs_embeds=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            **rotary_inputs,
        )
        return self.post_layernorm(encoded)
class SiglipVisionModel(SiglipPreTrainedModel):
    """Pretrained-model wrapper that owns a ``SiglipVisionTransformer``."""

    config_class = SiglipVisionConfig
    main_input_name = "packed_pixel_values"

    def __init__(self, config: SiglipVisionConfig):
        """Create the vision transformer and run the base-class ``post_init``."""
        super().__init__(config)
        self.vision_model = SiglipVisionTransformer(config)
        # Weight initialisation / final processing is delegated to the base class.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped transformer."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    def forward(
        self,
        packed_pixel_values: torch.Tensor,
        packed_flattened_position_ids: torch.LongTensor,
        cu_seqlens: torch.IntTensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Delegate to ``self.vision_model`` with the same four arguments."""
        encoded = self.vision_model(
            packed_pixel_values=packed_pixel_values,
            packed_flattened_position_ids=packed_flattened_position_ids,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return encoded
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/cache_utils/taylorseer.py
0 → 100644
View file @
876a36a4
"""
Utility for TaylorSeer
"""
# Adapted from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/taylorseer_utils/__init__.py
import
math
from
typing
import
Dict
import
torch
def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
    """
    Refresh the cached Taylor factors of the current module from a new feature.

    Order 0 is set to ``feature``. Each higher order is the difference between
    the freshly computed factor one order below and the previously cached
    factor of that same lower order, divided by the gap between the last two
    activated steps. The chain stops at the first order that has no cached
    predecessor, or immediately when ``current["step"]`` is not greater than
    ``cache_dic["first_enhance"] - 2``.

    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    :param feature: Newly computed feature stored as the order-0 factor
    """
    activated = current["activated_steps"]
    step_gap = activated[-1] - activated[-2]

    factors = {0: feature}
    for order in range(cache_dic["max_order"]):
        cached = cache_dic["cache"][-1][current["stream"]][current["layer"]][
            current["module"]
        ]
        previous = cached.get(order, None)
        if previous is None or not (
            current["step"] > cache_dic["first_enhance"] - 2
        ):
            break
        factors[order + 1] = (factors[order] - previous) / step_gap

    # Replace the whole per-module entry with the new set of factors.
    cache_dic["cache"][-1][current["stream"]][current["layer"]][
        current["module"]
    ] = factors
def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
    """
    Evaluate the cached Taylor expansion at the current step.

    Sums ``factor[i] * x**i / i!`` over every cached order ``i``, where ``x``
    is the number of steps elapsed since the most recent activated step.
    Returns the integer ``0`` when no factor is cached.

    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    # The offset is measured in step indices, not in timestep values.
    offset = current["step"] - current["activated_steps"][-1]
    factors = cache_dic["cache"][-1][current["stream"]][current["layer"]][
        current["module"]
    ]

    output = 0
    for order in range(len(factors)):
        coefficient = 1 / math.factorial(order)
        output += coefficient * factors[order] * (offset**order)
    return output
def taylor_cache_init(cache_dic: Dict, current: Dict):
    """
    Reset the Taylor-factor storage of the current module at the first step.

    Only acts when ``current["step"]`` is 0 and ``cache_dic["taylor_cache"]``
    is truthy; the per-module cache entry is then replaced by an empty dict.

    :param cache_dic: Cache dictionary
    :param current: Information of the current step
    """
    if current["step"] != 0 or not cache_dic["taylor_cache"]:
        return

    layer_cache = cache_dic["cache"][-1][current["stream"]][current["layer"]]
    layer_cache[current["module"]] = {}
# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/force_scheduler.py
def force_scheduler(cache_dic, current):
    """
    Compute the refresh threshold for the current step.

    Writes ``cache_dic["cal_threshold"]`` as a tensor:
    ``round(fresh_threshold / step_factor)``. The step factor depends on a
    linear step weight which is currently 0.0 in both modes, so the factor
    evaluates to 1. Nothing is returned.
    """
    if cache_dic["fresh_ratio"] == 0:
        linear_step_weight = 0.0  # FORA
    else:
        linear_step_weight = 0.0  # TokenCache

    step_factor = torch.tensor(
        1
        - linear_step_weight
        + 2 * linear_step_weight * current["step"] / current["num_steps"]
    )

    # No forced constraint is applied to sensitive steps: results were judged
    # good enough without one, though it may be worth experimenting with.
    cache_dic["cal_threshold"] = torch.round(
        cache_dic["fresh_threshold"] / step_factor
    )
# Copied from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cal_type.py
def cal_type(cache_dic, current):
    """
    Determine calculation type for this step.

    Sets ``current["type"]`` to ``"full"``, ``"Taylor"``, ``"ToCa"`` or
    ``"Delta-Cache"`` and updates ``cache_dic["cache_counter"]``. On a full
    step the counter is reset, the step is appended to
    ``current["activated_steps"]`` and ``force_scheduler`` is invoked.
    """
    uniform_fora = (cache_dic["fresh_ratio"] == 0.0) and (
        not cache_dic["taylor_cache"]
    )
    if uniform_fora:
        # FORA (uniform): only step 0 counts as a warm-up step.
        first_step = current["step"] == 0
    else:
        # ToCa (first enhanced): every step before `first_enhance` is warm-up.
        first_step = current["step"] < cache_dic["first_enhance"]

    fresh_interval = (
        cache_dic["fresh_threshold"] if first_step else cache_dic["cal_threshold"]
    )

    if first_step or cache_dic["cache_counter"] == fresh_interval - 1:
        current["type"] = "full"
        cache_dic["cache_counter"] = 0
        current["activated_steps"].append(current["step"])
        force_scheduler(cache_dic, current)
        return

    # Cached step: choose the type first, then count it.
    if cache_dic["taylor_cache"]:
        step_type = "Taylor"
    elif cache_dic["cache_counter"] % 2 == 1:
        # 0: ToCa-Aggressive-ToCa, 1: Aggressive-ToCa-Aggressive
        step_type = "ToCa"
    elif cache_dic["Delta-DiT"]:
        step_type = "Delta-Cache"
    else:
        step_type = "ToCa"

    cache_dic["cache_counter"] += 1
    current["type"] = step_type
# Modified from https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-xDiT/taylorseer_flux/cache_functions/cache_init.py
def cache_init(self, num_steps: int):
    """
    Initialization for cache.

    Builds the cache dictionary (one empty entry per layer of
    ``self.language_model.model.layers`` under the ``"layers_stream"`` stream,
    plus the caching hyper-parameters) and the per-run ``current`` state.

    :param num_steps: Total number of steps, stored in ``current["num_steps"]``
    :return: Tuple ``(cache_dic, current)``
    """
    num_layers = len(self.language_model.model.layers)

    cache = {-1: {"layers_stream": {j: {} for j in range(num_layers)}}}
    cache_index = {-1: {j: {} for j in range(num_layers)}, "layer_index": {}}

    cache_dic = {
        "cache_counter": 0,
        "Delta-DiT": False,
        "cache_type": "random",
        "cache_index": cache_index,
        "cache": cache,
        "fresh_ratio_schedule": "ToCa",
        "fresh_ratio": 0.0,
        "fresh_threshold": 3,
        "soft_fresh_weight": 0.0,
        "taylor_cache": True,
        "max_order": 6,
        "first_enhance": 5,
    }

    current = {
        "activated_steps": [0],
        "step": 0,
        "num_steps": num_steps,
    }

    return cache_dic, current
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment