Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
SenseNova-SI
Commits
876a36a4
Commit
876a36a4
authored
May 27, 2026
by
raojy
Browse files
first
parent
eda2afb8
Changes
175
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5223 additions
and
0 deletions
+5223
-0
SenseNova-SI-main/examples/Q8_4.jpg
SenseNova-SI-main/examples/Q8_4.jpg
+0
-0
SenseNova-SI-main/examples/Q9.jpg
SenseNova-SI-main/examples/Q9.jpg
+0
-0
SenseNova-SI-main/examples/bagel-generate-example.jpg
SenseNova-SI-main/examples/bagel-generate-example.jpg
+0
-0
SenseNova-SI-main/examples/bagel-think_generate-example.jpg
SenseNova-SI-main/examples/bagel-think_generate-example.jpg
+0
-0
SenseNova-SI-main/pyproject.toml
SenseNova-SI-main/pyproject.toml
+69
-0
SenseNova-SI-main/sensenova_si/__init__.py
SenseNova-SI-main/sensenova_si/__init__.py
+36
-0
SenseNova-SI-main/sensenova_si/bagel.py
SenseNova-SI-main/sensenova_si/bagel.py
+406
-0
SenseNova-SI-main/sensenova_si/bagel_utils/__init__.py
SenseNova-SI-main/sensenova_si/bagel_utils/__init__.py
+0
-0
SenseNova-SI-main/sensenova_si/bagel_utils/data/__init__.py
SenseNova-SI-main/sensenova_si/bagel_utils/data/__init__.py
+2
-0
SenseNova-SI-main/sensenova_si/bagel_utils/data/data_utils.py
...eNova-SI-main/sensenova_si/bagel_utils/data/data_utils.py
+198
-0
SenseNova-SI-main/sensenova_si/bagel_utils/data/transforms.py
...eNova-SI-main/sensenova_si/bagel_utils/data/transforms.py
+306
-0
SenseNova-SI-main/sensenova_si/bagel_utils/inferencer.py
SenseNova-SI-main/sensenova_si/bagel_utils/inferencer.py
+362
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/__init__.py
...ova-SI-main/sensenova_si/bagel_utils/modeling/__init__.py
+4
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/autoencoder.py
...-SI-main/sensenova_si/bagel_utils/modeling/autoencoder.py
+386
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/__init__.py
...-main/sensenova_si/bagel_utils/modeling/bagel/__init__.py
+17
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/bagel.py
...-SI-main/sensenova_si/bagel_utils/modeling/bagel/bagel.py
+1215
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/modeling_utils.py
...sensenova_si/bagel_utils/modeling/bagel/modeling_utils.py
+153
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/qwen2_navit.py
...in/sensenova_si/bagel_utils/modeling/bagel/qwen2_navit.py
+1457
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/siglip_navit.py
...n/sensenova_si/bagel_utils/modeling/bagel/siglip_navit.py
+419
-0
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/cache_utils/taylorseer.py
...nsenova_si/bagel_utils/modeling/cache_utils/taylorseer.py
+193
-0
No files found.
SenseNova-SI-main/examples/Q8_4.jpg
0 → 100644
View file @
876a36a4
58.5 KB
SenseNova-SI-main/examples/Q9.jpg
0 → 100644
View file @
876a36a4
226 KB
SenseNova-SI-main/examples/bagel-generate-example.jpg
0 → 100644
View file @
876a36a4
86.4 KB
SenseNova-SI-main/examples/bagel-think_generate-example.jpg
0 → 100644
View file @
876a36a4
66.5 KB
SenseNova-SI-main/pyproject.toml
0 → 100644
View file @
876a36a4
[project]
name
=
"SenseNova-SI"
version
=
"0.1.0"
description
=
"Scaling Spatial Intelligence with Multimodal Foundation Models"
readme
=
"README.md"
requires-python
=
">=3.11"
keywords
=
[
"computer vision"
,
"multimodal"
,
"spatial intelligence"
,
"MLLM"
]
dependencies
=
[
"transformers>=4.57.0"
,
"Pillow"
,
"numpy"
,
"setuptools"
,
"einops>=0.8.1"
,
"timm>=1.0.22"
,
"accelerate>=1.11.0"
,
"opencv-python>=4.11.0.86"
,
]
[dependency-groups]
flash-attn
=
["flash-attn==2.7.4.post1"]
dev
=
["ruff==0.14.4"]
[project.optional-dependencies]
cu118
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu121
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu124
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu126
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu128
=
[
"torch>=2.4.0"
,
"torchvision"
]
cu129
=
[
"torch>=2.4.0"
,
"torchvision"
]
[tool.uv]
default-groups
=
["flash-attn"]
no-build-isolation-package
=
[
'flash-attn'
,
'setuptools'
]
conflicts
=
[
[
{
extra
=
"cu118"
}
,
{
extra
=
"cu121"
}
,
{
extra
=
"cu124"
}
,
{
extra
=
"cu126"
}
,
{
extra
=
"cu128"
}
,
{
extra
=
"cu129"
}
,
],
]
index
=
[
{
name
=
"pytorch-cu118"
,
url
=
"https://download.pytorch.org/whl/cu118"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu121"
,
url
=
"https://download.pytorch.org/whl/cu121"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu124"
,
url
=
"https://download.pytorch.org/whl/cu124"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu126"
,
url
=
"https://download.pytorch.org/whl/cu126"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu128"
,
url
=
"https://download.pytorch.org/whl/cu128"
,
explicit
=
true
}
,
{
name
=
"pytorch-cu129"
,
url
=
"https://download.pytorch.org/whl/cu129"
,
explicit
=
true
}
,
]
[tool.uv.sources]
torch
=
[
{
index
=
"pytorch-cu118"
,
extra
=
"cu118"
}
,
{
index
=
"pytorch-cu121"
,
extra
=
"cu121"
}
,
{
index
=
"pytorch-cu124"
,
extra
=
"cu124"
}
,
{
index
=
"pytorch-cu126"
,
extra
=
"cu126"
}
,
{
index
=
"pytorch-cu128"
,
extra
=
"cu128"
}
,
{
index
=
"pytorch-cu129"
,
extra
=
"cu129"
}
,
]
torchvision
=
[
{
index
=
"pytorch-cu118"
,
extra
=
"cu118"
}
,
{
index
=
"pytorch-cu121"
,
extra
=
"cu121"
}
,
{
index
=
"pytorch-cu124"
,
extra
=
"cu124"
}
,
{
index
=
"pytorch-cu126"
,
extra
=
"cu126"
}
,
{
index
=
"pytorch-cu128"
,
extra
=
"cu128"
}
,
{
index
=
"pytorch-cu129"
,
extra
=
"cu129"
}
,
]
SenseNova-SI-main/sensenova_si/__init__.py
0 → 100644
View file @
876a36a4
from
.bagel
import
SenseNovaSIBagelModel
from
.internvl
import
SenseNovaSIInternVLModel
from
.qwen
import
SenseNovaSIQwenModel
def get_default_model_type(model_path):
    """Infer the model family from a checkpoint path or hub repo id.

    The check is a case-insensitive substring test, tried in the order
    "qwen", "internvl", "bagel"; the first family name found wins.

    Args:
        model_path: Local path or repo id. Anything ``str()`` can render is
            accepted, so ``pathlib.Path`` objects work as well as strings.

    Returns:
        One of ``"qwen"``, ``"internvl"`` or ``"bagel"``.

    Raises:
        ValueError: If no known family name occurs in ``model_path``.
    """
    # Normalise once: str() lets PathLike objects through (they have no
    # .lower() method) and the lowering is not repeated for every branch.
    lowered = str(model_path).lower()
    for family in ("qwen", "internvl", "bagel"):
        if family in lowered:
            return family
    raise ValueError(f"Unknown model type for {model_path}")
def
get_model
(
model_path
,
model_type
=
"auto"
):
if
model_type
==
"auto"
:
model_type
=
get_default_model_type
(
model_path
)
if
model_type
==
"qwen"
:
return
SenseNovaSIQwenModel
(
model_path
)
elif
model_type
==
"internvl"
:
return
SenseNovaSIInternVLModel
(
model_path
)
elif
model_type
==
"bagel"
:
return
SenseNovaSIBagelModel
(
model_path
)
else
:
raise
ValueError
(
f
"Unknown model type:
{
model_type
}
"
)
# Public API of the package: the two factory helpers plus the wrapper classes.
__all__ = [
    # factory helpers
    "get_default_model_type",
    "get_model",
    # model wrappers
    "SenseNovaSIInternVLModel",
    "SenseNovaSIQwenModel",
    "SenseNovaSIBagelModel",
]
SenseNova-SI-main/sensenova_si/bagel.py
0 → 100644
View file @
876a36a4
import
os
import
uuid
from
datetime
import
datetime
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
import
torch
from
accelerate
import
(
infer_auto_device_map
,
init_empty_weights
,
load_checkpoint_and_dispatch
,
)
from
huggingface_hub
import
snapshot_download
from
PIL
import
Image
from
.bagel_utils.data.transforms
import
ImageTransform
from
.bagel_utils.inferencer
import
InterleaveInferencer
from
.bagel_utils.modeling.autoencoder
import
load_ae
from
.bagel_utils.modeling.bagel
import
(
Bagel
,
BagelConfig
,
Qwen2Config
,
Qwen2ForCausalLM
,
SiglipVisionConfig
,
SiglipVisionModel
,
)
from
.bagel_utils.modeling.qwen2
import
Qwen2Tokenizer
from
.model
import
Model
from
.utils
import
add_special_tokens
# Per-mode keyword arguments for the BAGEL inferencer, keyed by mode name.
# "think" and "understanding_output" are flags that the model wrapper pops out
# before forwarding the remaining entries as inference hyper-parameters.
BASE_PARAMS: Dict[str, Dict[str, Any]] = {
    "generate": {
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 1.0,
        "cfg_interval": [0.4, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 1.0,
        "cfg_renorm_type": "global",
    },
    "think_generate": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 1.0,
        "cfg_interval": [0.4, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 1.0,
        "cfg_renorm_type": "global",
        "think": True,
    },
    "edit": {
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 2.0,
        "cfg_interval": [0.0, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 0.0,
        "cfg_renorm_type": "text_channel",
    },
    "think_edit": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "cfg_text_scale": 4.0,
        "cfg_img_scale": 2.0,
        "cfg_interval": [0.0, 1.0],
        "timestep_shift": 3.0,
        "num_timesteps": 50,
        "cfg_renorm_min": 0.0,
        "cfg_renorm_type": "text_channel",
        "think": True,
    },
    "understanding": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "understanding_output": True,
    },
    "think_understanding": {
        "max_think_token_n": 1000,
        "do_sample": False,
        "understanding_output": True,
        "think": True,
    },
}
class SenseNovaSIBagelModel(Model):
    """SenseNova-SI wrapper around the BAGEL model and its interleaved inferencer.

    The instance is bound to a single ``mode`` (a key of ``BASE_PARAMS``) at
    construction time; ``generate`` then returns either text or the path of a
    saved image depending on that mode.
    """

    # Values of ``dtype`` that ``_load_model_weights`` knows how to load.
    _SUPPORTED_PRECISIONS = ("bf16", "nf4", "int8")
    # Modes whose result is an image written to disk rather than text.
    _IMAGE_MODES = ("edit", "think_edit", "generate", "think_generate")

    def __init__(
        self,
        model_path="sensenova/SenseNova-SI-1.1-BAGEL-7B-MoT",
        generation_config: dict[str, Any] | str | os.PathLike | None = None,
        mode="understanding",
        out_img_dir="./output_images/test_bagel/",
        dtype: str = "bf16",
    ):
        """Resolve settings, build the model, load weights and create the inferencer.

        Args:
            model_path: Existing local directory, or a Hugging Face repo id that
                is fetched with ``snapshot_download`` when no such path exists.
            generation_config: Forwarded unchanged to ``Model.__init__``.
            mode: Key of ``BASE_PARAMS``. A non-blank ``BAGEL_MODE`` environment
                variable takes precedence over this argument.
            out_img_dir: Root directory for saved images. A non-blank
                ``BAGEL_OUT_IMG_DIR`` environment variable takes precedence.
            dtype: Loading precision: ``"bf16"``, ``"nf4"`` or ``"int8"``.

        Raises:
            NotImplementedError: If ``dtype`` is not a supported precision.
            ValueError: If the resolved mode is not a key of ``BASE_PARAMS``.
        """
        super().__init__(generation_config)

        # 1. Parse params
        # Validate the cheap arguments first so a typo fails immediately
        # instead of after a checkpoint download and a full model build.
        if dtype not in self._SUPPORTED_PRECISIONS:
            raise NotImplementedError(f"Unsupported precision: {dtype}")
        self.precision = dtype

        # Bagel mode: the environment variable wins when set to non-blank text.
        env_mode = (os.getenv("BAGEL_MODE") or "").strip()
        mode = env_mode or mode
        if mode not in BASE_PARAMS:
            raise ValueError(
                f"Invalid mode '{mode}'. "
                f"Bagel Supported modes: {list(BASE_PARAMS.keys())}"
            )
        self.mode = mode

        env_out_img_dir = (os.getenv("BAGEL_OUT_IMG_DIR") or "").strip()
        self.out_img_dir = env_out_img_dir or out_img_dir

        if os.path.exists(model_path):
            cache_path = model_path
        else:
            cache_path = snapshot_download(repo_id=model_path)
        self.model_path = cache_path
        self.checkpoint_path = os.path.join(self.model_path, "model.safetensors")

        print(
            f"[Bagel] mode = '{self.mode}' "
            "(can be overridden with env var BAGEL_MODE); "
            f"out_img_dir = '{self.out_img_dir}' "
            "(can be overridden with env var BAGEL_OUT_IMG_DIR)"
        )

        # 2. Build model
        model, vae_model, tokenizer, new_token_ids, vit_transform, vae_transform = (
            self._build_model()
        )

        # 3. Load Checkpoint
        model = self._load_model_weights(model)

        # 4. Build inferencer
        self.tokenizer = tokenizer
        self.new_token_ids = new_token_ids
        self.vit_transform = vit_transform
        self.inferencer = InterleaveInferencer(
            model=model,
            vae_model=vae_model,
            tokenizer=tokenizer,
            vae_transform=vae_transform,
            vit_transform=vit_transform,
            new_token_ids=new_token_ids,
        )
        torch.cuda.empty_cache()

    def _build_model(self):
        """Assemble the weight-less BAGEL model and its companion objects.

        Reads ``llm_config.json``, ``vit_config.json`` and ``ae.safetensors``
        from ``self.model_path`` and loads the tokenizer from the same place.

        Returns:
            Tuple ``(model, vae_model, tokenizer, new_token_ids, vit_transform,
            vae_transform)``. ``model`` is created under ``init_empty_weights``
            and still needs ``_load_model_weights``.
        """
        # build llm config
        llm_config = Qwen2Config.from_json_file(
            os.path.join(self.model_path, "llm_config.json")
        )
        llm_config.qk_norm = True
        llm_config.tie_word_embeddings = False
        llm_config.layer_module = "Qwen2MoTDecoderLayer"

        # build vit config
        vit_config = SiglipVisionConfig.from_json_file(
            os.path.join(self.model_path, "vit_config.json")
        )
        vit_config.rope = False
        # One layer fewer than the config file declares.
        # NOTE(review): presumably mirrors how the checkpoint was trained - confirm.
        vit_config.num_hidden_layers -= 1

        # ImageTransform(max_image_size, min_image_size, image_stride)
        vit_transform = ImageTransform(980, 224, 14)
        vae_transform = ImageTransform(1024, 512, 16)

        # build vae config
        vae_model, vae_config = load_ae(
            local_path=os.path.join(self.model_path, "ae.safetensors")
        )

        # build tokenizer
        tokenizer = Qwen2Tokenizer.from_pretrained(self.model_path)
        tokenizer, new_token_ids, _ = add_special_tokens(tokenizer)

        # build model
        model_config = BagelConfig(
            visual_gen=True,
            visual_und=True,
            llm_config=llm_config,
            vit_config=vit_config,
            vae_config=vae_config,
            latent_patch_size=2,
            max_latent_size=64,
            vit_max_num_patch_per_side=70,
            connector_act="gelu_pytorch_tanh",
        )
        with init_empty_weights():
            language_model = Qwen2ForCausalLM(llm_config)
            vit_model = SiglipVisionModel(vit_config)
            model = Bagel(language_model, vit_model, model_config)
            model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)

        return model, vae_model, tokenizer, new_token_ids, vit_transform, vae_transform

    def _load_model_weights(self, model):
        """Load ``self.checkpoint_path`` into ``model`` at ``self.precision``.

        Args:
            model: The weight-less model produced by ``_build_model``.

        Returns:
            The loaded model, switched to eval mode.

        Raises:
            NotImplementedError: If ``self.precision`` is not supported.
        """
        device_map = infer_auto_device_map(
            model, no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"]
        )

        # Modules that are forced onto the device of the first entry.
        same_device_modules = [
            "language_model.model.embed_tokens",
            "time_embedder",
            "latent_pos_embed",
            "vae2llm",
            "llm2vae",
            "connector",
            "vit_pos_embed",
        ]
        anchor = same_device_modules[0]

        if torch.cuda.device_count() == 1:
            first_device = device_map.get(anchor, "cuda:0")
            for k in same_device_modules:
                device_map[k] = first_device if k in device_map else "cuda:0"
        else:
            first_device = device_map.get(anchor)
            # If the anchor has no entry of its own there is nothing to align
            # to; writing None into the map would leave modules without a
            # device, so the automatic placement is kept as is.
            if first_device is not None:
                for k in same_device_modules:
                    if k in device_map:
                        device_map[k] = first_device

        if self.precision == "bf16":
            model = load_checkpoint_and_dispatch(
                model,
                checkpoint=self.checkpoint_path,
                device_map=device_map,
                offload_buffers=True,
                offload_folder="offload",
                dtype=torch.bfloat16,
                force_hooks=True,
            )
        elif self.precision in ("nf4", "int8"):
            # Imported only on the quantised paths, as in the original code.
            from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model

            if self.precision == "nf4":
                bnb_config = BnbQuantizationConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=False,
                    bnb_4bit_quant_type="nf4",
                )
            else:
                bnb_config = BnbQuantizationConfig(
                    load_in_8bit=True, torch_dtype=torch.float32
                )
            model = load_and_quantize_model(
                model,
                weights_location=self.checkpoint_path,
                bnb_quantization_config=bnb_config,
                device_map=device_map,
                offload_folder="offload",
            )
        else:
            raise NotImplementedError(f"Unsupported precision: {self.precision}")

        return model.eval()

    @staticmethod
    def _unique_filename(prefix: str) -> str:
        """Return ``<prefix>_<timestamp>_<8 hex chars>.jpg``."""
        ts = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
        return f"{prefix}_{ts}_{uuid.uuid4().hex[:8]}.jpg"

    def _save_output_image(
        self,
        image: Image.Image,
        mode: str,
        img_path: Optional[str],
    ) -> str:
        """Write ``image`` below ``<out_img_dir>/images`` and return its path.

        Edit modes with a reference ``img_path`` reuse that file's name inside
        a sub-directory named after its parent folder ("default" when the
        parent has no name); edit modes without a reference write a uniquely
        named file under ``edit/``. Generate modes always write a uniquely
        named file directly under ``images/`` and ignore ``img_path``.

        Args:
            image: Image to save.
            mode: One of the edit / generate modes.
            img_path: Path of the single input image, or ``None``.

        Returns:
            The path of the written file, as a string.

        Raises:
            ValueError: If ``image`` is ``None`` or ``mode`` is not an image mode.
        """
        if image is None:
            raise ValueError(
                f"[OutputError] Mode={mode} expected an image output, but got None."
            )

        images_root = Path(self.out_img_dir) / "images"

        if mode in ("edit", "think_edit"):
            if img_path:
                src = Path(img_path)
                out_dir = images_root / (src.parent.name or "default")
                filename = src.name
            else:
                out_dir = images_root / "edit"
                filename = self._unique_filename("sample_edit")
        elif mode in ("generate", "think_generate"):
            out_dir = images_root
            filename = self._unique_filename("sample")
        else:
            raise ValueError(
                f"[OutputError] Unexpected mode for image saving: {mode}"
            )

        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / filename
        image.save(out_path)
        return str(out_path)

    def generate(self, question: str, images: list[str] | None = None, **kwargs):
        """Run one inference call in the mode chosen at construction time.

        ``question`` is split on ``"<image>"`` and the images are interleaved
        with the resulting text pieces. When images are given but the question
        has no placeholder, one ``"<image>\\n"`` per image is prepended.

        Args:
            question: Prompt text, optionally containing ``<image>`` markers.
            images: Paths of the input images, in placeholder order.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            For edit / generate modes, the path of the saved image, or ``None``
            when the inferencer produced no image. For all other modes, the
            first text output, or ``None`` when there was none.

        Raises:
            ValueError: If the placeholder count does not match ``images``.
            RuntimeError: If an input image cannot be opened.
        """
        mode = self.mode
        images = images or []

        # Auto-prepend <image> placeholders if the question doesn't contain them
        if images and "<image>" not in question:
            question = "<image>\n" * len(images) + question

        text_parts = question.split("<image>")
        if len(text_parts) != len(images) + 1:
            raise ValueError("Text image tokens and number of images do not match!")

        input_lists = []
        input_img_paths = []
        for i, part in enumerate(text_parts):
            text = part.strip()
            if text:
                input_lists.append(text)
            if i < len(images):
                img_path = images[i]
                try:
                    image = Image.open(img_path)
                except Exception as e:
                    raise RuntimeError(f"Can not load image {img_path}: {e}") from e
                input_lists.append(image)
                input_img_paths.append(img_path)

        # Shallow copy so popping the flags does not alter BASE_PARAMS.
        params = dict(BASE_PARAMS[mode])
        understanding_output_flag = params.pop("understanding_output", False)
        think_flag = params.pop("think", False)

        res = self.inferencer.interleave_inference(
            input_lists=input_lists,
            think=think_flag,
            understanding_output=understanding_output_flag,
            **params,
        )

        # Single pass over the result: it is only known to be iterable here.
        out_images = []
        out_texts = []
        for item in res:
            if isinstance(item, Image.Image):
                out_images.append(item)
            elif isinstance(item, str):
                out_texts.append(item)

        img_cnt, txt_cnt = len(out_images), len(out_texts)
        if img_cnt + txt_cnt != 1:
            print(
                f"[Warning] You are using {mode} mode, so the output has "
                f"{img_cnt} images and {txt_cnt} texts"
            )
            if txt_cnt > 0:
                print(f"[Warning] The text output is: {out_texts[0]}")

        if mode not in self._IMAGE_MODES:
            return out_texts[0] if out_texts else None

        if not out_images:
            return None

        # Only an unambiguous single input image is used as naming reference.
        ref_img_path = input_img_paths[0] if len(input_img_paths) == 1 else None
        return self._save_output_image(
            image=out_images[0],
            mode=mode,
            img_path=ref_img_path,
        )
SenseNova-SI-main/sensenova_si/bagel_utils/__init__.py
0 → 100644
View file @
876a36a4
SenseNova-SI-main/sensenova_si/bagel_utils/data/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
SenseNova-SI-main/sensenova_si/bagel_utils/data/data_utils.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
math
import
random
import
torch
from
PIL
import
Image
# from torch.nn.attention.flex_attention import or_masks, and_masks
def
create_sparse_mask
(
document_lens
,
split_lens
,
attn_modes
,
device
):
def
causal_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
q_idx
>=
kv_idx
def
full_and_noise_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
(
full_and_noise_seq_id
[
q_idx
]
==
full_and_noise_seq_id
[
kv_idx
])
&
(
full_and_noise_seq_id
[
q_idx
]
>=
0
)
def
remove_noise_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
~
(
(
noise_seq_id
[
kv_idx
]
>=
0
)
&
(
noise_seq_id
[
q_idx
]
!=
noise_seq_id
[
kv_idx
])
)
def
sample_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
document_id
[
q_idx
]
==
document_id
[
kv_idx
]
full_and_noise_tmp
=
[]
noise_tmp
=
[]
for
i
,
(
length
,
model
)
in
enumerate
(
zip
(
split_lens
,
attn_modes
)):
value
=
i
if
model
in
[
"full"
,
"noise"
]
else
-
1
full_and_noise_tmp
.
extend
([
value
]
*
length
)
value_noise
=
i
if
model
==
"noise"
else
-
1
noise_tmp
.
extend
([
value_noise
]
*
length
)
full_and_noise_seq_id
=
torch
.
Tensor
(
full_and_noise_tmp
).
to
(
device
)
noise_seq_id
=
torch
.
Tensor
(
noise_tmp
).
to
(
device
)
document_id
=
torch
.
cat
(
[
torch
.
full
((
l
,),
i
)
for
i
,
l
in
enumerate
(
document_lens
,
start
=
1
)]
).
to
(
device
)
return
and_masks
(
or_masks
(
causal_mask
,
full_and_noise_mask
),
remove_noise_mask
,
sample_mask
)
def patchify(image, patch_size):
    """Cut a ``(c, h, w)`` tensor into flattened square patches.

    Args:
        image: Tensor whose shape unpacks into ``(c, h, w)``.
        patch_size: Patch edge length; must divide both ``h`` and ``w``.

    Returns:
        Tensor of shape ``(h // p * w // p, p * p * c)``. Patches are ordered
        row-major; inside a patch the values are laid out pixel by pixel with
        the channel as the fastest-changing axis.
    """
    channels, height, width = image.shape
    assert height % patch_size == 0 and width % patch_size == 0

    rows, cols = height // patch_size, width // patch_size
    blocks = image.reshape(channels, rows, patch_size, cols, patch_size)
    # (c, rows, p, cols, q) -> (rows, cols, p, q, c)
    blocks = blocks.permute(1, 3, 2, 4, 0)
    return blocks.reshape(-1, patch_size**2 * channels)
def get_flattened_position_ids_extrapolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return row-major patch position ids on a grid of fixed row stride.

    The patch at ``(row, col)`` gets the id
    ``row * max_num_patches_per_side + col``.

    Args:
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        patch_size: Patch edge length; the patch counts use floor division.
        max_num_patches_per_side: Row stride of the id grid.

    Returns:
        1-D tensor with one id per patch.
    """
    row_ids = torch.arange(0, img_h // patch_size)
    col_ids = torch.arange(0, img_w // patch_size)
    grid = row_ids.unsqueeze(1) * max_num_patches_per_side + col_ids
    return grid.flatten()
def get_flattened_position_ids_interpolate(img_h, img_w, patch_size, max_num_patches_per_side):
    """Return patch position ids obtained by bucketing fractional coordinates.

    Each patch index is turned into a fraction of the image side in ``[0, 1)``
    and bucketed into ``max_num_patches_per_side`` equal bins; the patch at
    ``(row, col)`` then gets ``row_bucket * max_num_patches_per_side + col_bucket``.

    Args:
        img_h: Image height in pixels.
        img_w: Image width in pixels.
        patch_size: Patch edge length; the patch counts use floor division.
        max_num_patches_per_side: Number of bins per side and row stride.

    Returns:
        1-D tensor with one id per patch, in row-major order.
    """
    bin_width = 1 / max_num_patches_per_side
    boundaries = torch.arange(bin_width, 1.0, bin_width)

    def _bucketed(num_patches):
        # Fractions 0, 1/n, 2/n, ...; the 1e-6 margin keeps 1.0 itself out.
        fractions = torch.arange(0, 1 - 1e-6, 1 / num_patches)
        return torch.bucketize(fractions, boundaries, right=True)

    row_buckets = _bucketed(img_h // patch_size)
    col_buckets = _bucketed(img_w // patch_size)
    grid = row_buckets[:, None] * max_num_patches_per_side + col_buckets
    return grid.flatten()
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
    """Build the additive attention mask for one sample made of several splits.

    Every split can attend to all positions before it. Inside a split,
    "causal" uses a lower-triangular pattern while "full" and "noise" use a
    dense block. Afterwards the columns of every "noise" split are hidden from
    all rows except the rows of that same split.

    Args:
        split_lens: List of ints, the length of each split in the sample.
        attn_modes: One of "causal", "full" or "noise" per split.
        device: Device on which the mask is built.

    Returns:
        Float tensor of shape ``(sum(split_lens), sum(split_lens))`` holding
        ``0`` where attention is allowed and ``-inf`` where it is blocked.
    """
    sample_len = sum(split_lens)
    attention_mask = torch.zeros((sample_len, sample_len), dtype=torch.bool, device=device)

    csum = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        assert attn_mode in ["causal", "full", "noise"]
        end = csum + s
        if attn_mode == "causal":
            attention_mask[csum:end, csum:end] = torch.ones((s, s), device=device).tril()
        else:
            # Scalar fill: no temporary tensor on the default device that
            # would then have to be copied over to `device`.
            attention_mask[csum:end, csum:end] = True
        # Every split sees everything that precedes it.
        attention_mask[csum:end, :csum] = True
        csum = end

    csum = 0
    for s, attn_mode in zip(split_lens, attn_modes):
        if attn_mode == "noise":
            end = csum + s
            # Hide the noise columns from everyone, then re-open the block
            # so the noise split still attends to itself.
            attention_mask[:, csum:end] = False
            attention_mask[csum:end, csum:end] = True
        csum += s

    attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
        ~attention_mask, float("-inf")
    )
    return attention_mask
def split_integer_exp_decay(S, ng_sample_decay=1.0):
    """Randomly split the integer ``S`` into positive parts.

    The number of parts ``N`` is uniform on ``1..S`` when ``ng_sample_decay``
    is exactly 1.0; otherwise ``N = k`` is drawn with a weight proportional to
    ``ng_sample_decay ** (k - 1)``. ``N - 1`` distinct cut points are then
    sampled from ``1..S-1``.

    Args:
        S: Integer to split.
        ng_sample_decay: Decay factor of the distribution over ``N``.

    Returns:
        Tuple ``(parts, boundaries)``: ``parts`` sums to ``S`` and
        ``boundaries`` is the sorted cumulative list, from 0 up to ``S``.
    """
    if ng_sample_decay == 1.0:
        num_parts = random.randint(1, S)
    else:
        # Normalised geometric weights over the S possible part counts.
        norm = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
        weights = [norm * math.pow(ng_sample_decay, k) for k in range(S)]
        candidates = list(range(1, S + 1))
        num_parts = random.choices(candidates, weights, k=1)[0]

    cuts = sorted(random.sample(range(1, S), num_parts - 1))
    cumsum = [0, *cuts, S]
    result = [hi - lo for lo, hi in zip(cumsum[:-1], cumsum[1:])]
    return result, cumsum
def pil_img2rgb(image):
    """Convert a PIL image to RGB, compositing transparency onto white.

    Args:
        image: PIL image in any mode.

    Returns:
        An RGB image. Images in "RGBA" mode, or carrying a "transparency"
        entry in ``image.info``, are pasted onto a white canvas using their
        alpha channel as mask; all others are converted directly.
    """
    has_transparency = (
        image.mode == "RGBA" or image.info.get("transparency", None) is not None
    )
    if not has_transparency:
        return image.convert("RGB")

    rgba = image.convert("RGBA")
    canvas = Image.new(mode="RGB", size=rgba.size, color=(255, 255, 255))
    # Band 3 of an RGBA image is the alpha channel.
    canvas.paste(rgba, mask=rgba.split()[3])
    return canvas
def add_special_tokens(tokenizer):
    """Make sure the four chat / vision marker tokens exist in ``tokenizer``.

    Any of ``<|im_start|>``, ``<|im_end|>``, ``<|vision_start|>`` and
    ``<|vision_end|>`` that is not already among the values of
    ``tokenizer.special_tokens_map`` is added via ``tokenizer.add_tokens``.

    Args:
        tokenizer: Object exposing ``special_tokens_map``, ``add_tokens`` and
            ``convert_tokens_to_ids``.

    Returns:
        Tuple ``(tokenizer, new_token_ids, num_new_tokens)`` where
        ``new_token_ids`` maps ``bos_token_id``, ``eos_token_id``,
        ``start_of_image`` and ``end_of_image`` to the marker ids, and
        ``num_new_tokens`` is the value returned by ``add_tokens``.
    """
    markers = ("<|im_start|>", "<|im_end|>", "<|vision_start|>", "<|vision_end|>")

    # Flatten the special-token map: values are single strings or lists.
    existing = []
    for value in tokenizer.special_tokens_map.values():
        if isinstance(value, str):
            existing.append(value)
        elif isinstance(value, list):
            existing.extend(value)

    missing = [token for token in markers if token not in existing]
    num_new_tokens = tokenizer.add_tokens(missing)

    bos_id, eos_id, image_start_id, image_end_id = [
        tokenizer.convert_tokens_to_ids(token) for token in markers
    ]
    new_token_ids = {
        "bos_token_id": bos_id,
        "eos_token_id": eos_id,
        "start_of_image": image_start_id,
        "end_of_image": image_end_id,
    }
    return tokenizer, new_token_ids, num_new_tokens
def
len2weight
(
x
,
loss_reduction
=
"square"
):
if
x
==
0
:
return
x
if
loss_reduction
==
"token"
:
return
1
if
loss_reduction
==
"sample"
:
return
1
/
x
if
loss_reduction
==
"square"
:
return
1
/
(
x
**
0.5
)
raise
NotImplementedError
(
loss_reduction
)
SenseNova-SI-main/sensenova_si/bagel_utils/data/transforms.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
random
import
cv2
import
numpy
as
np
import
torch
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision.transforms
import
InterpolationMode
from
torchvision.transforms
import
functional
as
F
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
    """Resize an image into a size range with both sides divisible by a stride.

    The image is scaled down so its longest edge is at most ``max_size``, then
    scaled up if needed so its shortest edge is at least ``min_size``, and is
    finally limited by a pixel budget. Both sides are snapped to multiples of
    ``stride`` after every scaling step.

    Args:
        max_size (int): Upper bound for the longest edge.
        min_size (int): Lower bound for the shortest edge.
        stride (int): Value both output sides are multiples of.
        max_pixels (int): Pixel budget for the image (divided by ``img_num``
            in ``forward``).
        interpolation (InterpolationMode): Interpolation passed to
            ``F.resize``. Default is ``InterpolationMode.BICUBIC``.
        antialias (bool, optional): Passed to ``F.resize`` (default True).
    """

    def __init__(
        self,
        max_size: int,
        min_size: int,
        stride: int,
        max_pixels: int,
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ):
        super().__init__()
        self.max_size, self.min_size = max_size, min_size
        self.stride = stride
        self.max_pixels = max_pixels
        self.interpolation = interpolation
        self.antialias = antialias

    def _make_divisible(self, value, stride):
        """Round ``value`` to the nearest multiple of ``stride`` (at least one stride)."""
        snapped = int(round(value / stride) * stride)
        return max(stride, snapped)

    def _apply_scale(self, width, height, scale):
        """Scale both sides by ``scale`` and snap each one to the stride."""
        scaled_w = self._make_divisible(round(width * scale), self.stride)
        scaled_h = self._make_divisible(round(height * scale), self.stride)
        return scaled_w, scaled_h

    def forward(self, img, img_num=1):
        """Resize ``img`` according to the configured limits.

        Args:
            img (PIL Image or Tensor): Image to resize. For tensors the last
                two dimensions are read as (height, width).
            img_num (int): Number of images sharing the pixel budget; each
                image gets ``max_pixels / img_num``.

        Returns:
            PIL Image or Tensor: The resized image.
        """
        if isinstance(img, torch.Tensor):
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        long_edge = max(width, height)
        short_edge = min(width, height)
        # Shrink to the long-edge limit first; the short-edge minimum wins
        # whenever the two constraints disagree.
        scale = max(min(self.max_size / long_edge, 1.0), self.min_size / short_edge)
        new_width, new_height = self._apply_scale(width, height, scale)

        # Ensure the number of pixels does not exceed max_pixels
        # NOTE: the area ratio itself is applied to each side, so the area
        # shrinks by about the square of that ratio.
        if new_width * new_height > self.max_pixels / img_num:
            scale = self.max_pixels / img_num / (new_width * new_height)
            new_width, new_height = self._apply_scale(new_width, new_height, scale)

        # Ensure longest edge does not exceed max_size
        longest = max(new_width, new_height)
        if longest > self.max_size:
            new_width, new_height = self._apply_scale(
                new_width, new_height, self.max_size / longest
            )

        return F.resize(
            img, (new_height, new_width), self.interpolation, antialias=self.antialias
        )
class ImageTransform:
    """Resize, tensorise and normalise an image in one call.

    Args:
        max_image_size: Upper bound for the longest edge after resizing.
        min_image_size: Lower bound for the shortest edge after resizing.
        image_stride: Value both output sides are multiples of; also stored as
            ``self.stride``.
        max_pixels: Pixel budget handed to the resize step.
        image_mean: Per-channel mean for normalisation.
        image_std: Per-channel standard deviation for normalisation.
    """

    def __init__(
        self,
        max_image_size,
        min_image_size,
        image_stride,
        max_pixels=14 * 14 * 9 * 1024,
        # Tuples rather than lists: default arguments are shared between
        # calls, so they should not be mutable.
        image_mean=(0.5, 0.5, 0.5),
        image_std=(0.5, 0.5, 0.5),
    ):
        self.stride = image_stride

        self.resize_transform = MaxLongEdgeMinShortEdgeResize(
            max_size=max_image_size,
            min_size=min_image_size,
            stride=image_stride,
            max_pixels=max_pixels,
        )
        self.to_tensor_transform = transforms.ToTensor()
        self.normalize_transform = transforms.Normalize(
            mean=image_mean, std=image_std, inplace=True
        )

    def __call__(self, img, img_num=1):
        """Apply resize, ``ToTensor`` and ``Normalize`` in that order.

        Args:
            img: Input image.
            img_num: Forwarded to the resize step, which divides the pixel
                budget by it.

        Returns:
            The normalised image tensor.
        """
        img = self.resize_transform(img, img_num=img_num)
        img = self.to_tensor_transform(img)
        img = self.normalize_transform(img)
        return img
def decolorization(image):
    """Return a grayscale version of ``image``.

    Args:
        image: PIL image.

    Returns:
        For "RGB" input, an RGB image whose three bands all hold the grayscale
        values. For every other mode, the single-band "L" conversion.
    """
    gray_image = image.convert("L")
    # Only RGB needs the gray band replicated. An "L" input must not go
    # through Image.merge: merging three bands into a one-band mode raises.
    if image.mode == "RGB":
        return Image.merge("RGB", [gray_image] * 3)
    return gray_image
def downscale(image, scale_factor):
    """Resize ``image`` by ``scale_factor`` with bicubic resampling.

    Args:
        image: PIL image.
        scale_factor: Multiplier applied to width and height.

    Returns:
        The resized image; each side is rounded and is at least 1 pixel.
    """
    target_size = tuple(
        max(1, int(round(side * scale_factor))) for side in (image.width, image.height)
    )
    return image.resize(target_size, resample=Image.BICUBIC)
def crop(image, crop_factors):
    """Cut a randomly positioned window out of ``image``.

    Args:
        image: PIL image.
        crop_factors: ``(target_h, target_w)`` of the window in pixels.

    Returns:
        Tuple ``(cropped_image, corners)`` where ``corners`` is
        ``[[x0, y0], [x1, y1]]``, the top-left and bottom-right of the window.

    Raises:
        ValueError: If the window is larger than the image on either axis.
    """
    target_h, target_w = crop_factors
    img_w, img_h = image.size

    fits = target_h <= img_h and target_w <= img_w
    if not fits:
        raise ValueError("Crop size exceeds image dimensions")

    left = random.randint(0, img_w - target_w)
    top = random.randint(0, img_h - target_h)
    right, bottom = left + target_w, top + target_h

    return image.crop((left, top, right, bottom)), [[left, top], [right, bottom]]
def motion_blur_opencv(image, kernel_size=15, angle=0):
    """Apply a linear motion blur using a rotated line kernel.

    Args:
        image: Input image; converted with ``np.array``.
        kernel_size: Side length of the square blur kernel.
        angle: Rotation angle handed to ``cv2.getRotationMatrix2D``.

    Returns:
        The blurred image as a PIL image built from a uint8 array.
    """
    # Line kernel: ones along the middle row
    kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
    kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)

    # Rotate the kernel around its centre
    mid = kernel_size / 2 - 0.5
    rotation = cv2.getRotationMatrix2D((mid, mid), angle, 1)
    rotated_kernel = cv2.warpAffine(kernel, rotation, (kernel_size, kernel_size))

    # Normalise the kernel (left untouched if it sums to zero)
    kernel_sum = rotated_kernel.sum()
    rotated_kernel /= kernel_sum if kernel_sum != 0 else 1

    def _convolve(plane):
        return cv2.filter2D(plane, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)

    img = np.array(image)
    if img.ndim == 2:
        blurred = _convolve(img)
    else:
        # Colour image: convolve every channel independently
        blurred = np.zeros_like(img)
        for channel in range(img.shape[2]):
            blurred[..., channel] = _convolve(img[..., channel])

    return Image.fromarray(blurred.astype(np.uint8))
def shuffle_patch(image, num_splits, gap_size=2):
    """Split an image into a grid of patches, shuffle them and re-assemble.

    The image size does not have to be divisible by the grid: the last row /
    column absorbs the remainder. Patches are pasted back in shuffled order
    with a gap of ``gap_size`` pixels between neighbouring grid cells.

    Args:
        image: PIL image.
        num_splits: ``(h_splits, w_splits)``, the number of rows and columns.
        gap_size: Gap in pixels between adjacent cells.

    Returns:
        A new image in the input's mode, created with a (255, 255, 255) fill.
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    def _edge_lengths(total, parts):
        # Equal floor-sized pieces; the last piece takes whatever is left.
        lengths = [total // parts] * (parts - 1)
        lengths.append(total - sum(lengths))
        return lengths

    patch_heights = _edge_lengths(img_h, h_splits)
    patch_widths = _edge_lengths(img_w, w_splits)

    # Cut the patches in row-major order.
    patches = []
    top = 0
    for patch_h in patch_heights:
        left = 0
        for patch_w in patch_widths:
            patches.append(image.crop((left, top, left + patch_w, top + patch_h)))
            left += patch_w
        top += patch_h

    random.shuffle(patches)

    total_width = sum(patch_widths) + (w_splits - 1) * gap_size
    total_height = sum(patch_heights) + (h_splits - 1) * gap_size
    new_image = Image.new(
        image.mode, (total_width, total_height), color=(255, 255, 255)
    )

    # Paste the shuffled patches cell by cell; each cell's top-left corner
    # advances by the cell size plus the gap.
    shuffled = iter(patches)
    top = 0
    for patch_h in patch_heights:
        left = 0
        for patch_w in patch_widths:
            new_image.paste(next(shuffled), (left, top))
            left += patch_w + gap_size
        top += patch_h + gap_size

    return new_image
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
    """Blank out a random subset of grid patches, for inpainting tasks.

    Args:
        image: PIL image.
        num_splits: ``(h_splits, w_splits)``, the number of rows and columns
            of the patch grid. The last row / column absorbs any remainder.
        blank_ratio: Fraction of patches to blank; the count is
            ``int(total_patches * blank_ratio)`` clamped to
            ``[0, total_patches]``.
        blank_color: Fill colour of blanked patches.

    Returns:
        A new RGB image of the original size with the selected patches
        replaced by ``blank_color``.
    """
    h_splits, w_splits = num_splits
    img_w, img_h = image.size

    def _edge_lengths(total, parts):
        # Equal floor-sized pieces; the last piece takes whatever is left.
        lengths = [total // parts] * (parts - 1)
        lengths.append(total - sum(lengths))
        return lengths

    patch_heights = _edge_lengths(img_h, h_splits)
    patch_widths = _edge_lengths(img_w, w_splits)

    # Cut the patches in row-major order.
    patches = []
    top = 0
    for patch_h in patch_heights:
        left = 0
        for patch_w in patch_widths:
            patches.append(image.crop((left, top, left + patch_w, top + patch_h)))
            left += patch_w
        top += patch_h

    total_patches = h_splits * w_splits
    num_blank = max(0, min(int(total_patches * blank_ratio), total_patches))
    blank_indices = set(random.sample(range(total_patches), num_blank))

    processed_patches = [
        Image.new("RGB", patch.size, color=blank_color) if idx in blank_indices else patch
        for idx, patch in enumerate(patches)
    ]

    # Result image has the same size as the input
    result_image = Image.new("RGB", (img_w, img_h))
    remaining = iter(processed_patches)
    top = 0
    for patch_h in patch_heights:
        left = 0
        for patch_w in patch_widths:
            # Paste each patch back at its original position
            result_image.paste(next(remaining), (left, top))
            left += patch_w
        top += patch_h

    return result_image
SenseNova-SI-main/sensenova_si/bagel_utils/inferencer.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
copy
import
deepcopy
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
import
torch
from
PIL
import
Image
from
.data.data_utils
import
pil_img2rgb
from
.modeling.bagel.qwen2_navit
import
NaiveCache
# System prompt prepended by InterleaveInferencer.interleave_inference when
# think=True and understanding_output=True (text answer expected).
VLM_THINK_SYSTEM_PROMPT = """You should first think about the reasoning process in the mind and then provide the user with the answer.
The reasoning process is enclosed within <think> </think> tags, i.e. <think> reasoning process here </think> answer here"""

# System prompt prepended by InterleaveInferencer.interleave_inference when
# think=True and understanding_output=False (image generation expected).
GEN_THINK_SYSTEM_PROMPT = """You should first think about the planning process in the mind and then generate the image.
The planning process is enclosed within <think> </think> tags, i.e. <think> planning process here </think> image here"""
class InterleaveInferencer:
    """Drive interleaved text/image inference on top of a model and a VAE.

    A "generation context" is a plain dict with three keys:
    ``"kv_lens"``, ``"ropes"`` and ``"past_key_values"``. The ``update_context_*``
    methods feed text or images into the model and store the values returned
    by the model back into that dict; ``gen_text`` / ``gen_image`` generate
    from it.
    """

    def __init__(self, model, vae_model, tokenizer, vae_transform, vit_transform, new_token_ids):
        """Store the collaborators; no work is done here.

        Args:
            model: Object providing the ``prepare_*``, ``forward_cache_update_*``,
                ``generate_text`` and ``generate_image`` methods used below.
            vae_model: Passed to ``model.forward_cache_update_vae`` and used
                for ``decode`` in :meth:`decode_image`.
            tokenizer: Passed to ``model.prepare_prompts`` and used to decode
                generated token ids.
            vae_transform: Transform handed to ``model.prepare_vae_images``;
                its ``resize_transform`` is applied to input images.
            vit_transform: Transform handed to ``model.prepare_vit_images``.
            new_token_ids: Mapping of special token ids; ``"eos_token_id"`` is
                read in :meth:`gen_text`.
        """
        self.model = model
        self.vae_model = vae_model
        self.tokenizer = tokenizer
        self.vae_transform = vae_transform
        self.vit_transform = vit_transform
        self.new_token_ids = new_token_ids

    def init_gen_context(self):
        """Return a fresh, empty generation context."""
        gen_context = {
            "kv_lens": [0],
            "ropes": [0],
            "past_key_values": NaiveCache(self.model.config.llm_config.num_hidden_layers),
        }
        return gen_context

    @torch.no_grad()
    def update_context_text(self, text, gen_context):
        """Feed one text prompt into ``gen_context`` and return it.

        ``gen_context`` is updated in place (its three entries are
        overwritten) and the same dict is returned. Currently only a single
        sample per call is supported.
        """
        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]
        generation_input, kv_lens, ropes = self.model.prepare_prompts(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            prompts=[text],
            tokenizer=self.tokenizer,
            new_token_ids=self.new_token_ids,
        )

        past_key_values = self.model.forward_cache_update_text(past_key_values, **generation_input)
        gen_context["kv_lens"] = kv_lens
        gen_context["ropes"] = ropes
        gen_context["past_key_values"] = past_key_values

        return gen_context

    @torch.no_grad()
    def update_context_image(self, image, gen_context, vae=True, vit=True):
        """Feed one image into ``gen_context`` and return it.

        Args:
            image: Image to add; it is passed to the model's
                ``prepare_vae_images`` / ``prepare_vit_images``.
            gen_context: Context dict; updated in place and returned.
            vae: If true, run the VAE path (``forward_cache_update_vae``).
            vit: If true, run the ViT path (``forward_cache_update_vit``).

        At least one of ``vae`` / ``vit`` must be true. When both are set the
        VAE update runs first and the ViT update continues from its result.
        Currently only a single sample per call is supported.
        """
        assert vae or vit
        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]

        if vae:
            # VAE path
            generation_input, kv_lens, ropes = self.model.prepare_vae_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vae_transform,
                new_token_ids=self.new_token_ids,
            )
            past_key_values = self.model.forward_cache_update_vae(
                self.vae_model, past_key_values, **generation_input
            )

        if vit:
            # ViT path
            generation_input, kv_lens, ropes = self.model.prepare_vit_images(
                curr_kvlens=kv_lens,
                curr_rope=ropes,
                images=[image],
                transforms=self.vit_transform,
                new_token_ids=self.new_token_ids,
            )
            past_key_values = self.model.forward_cache_update_vit(past_key_values, **generation_input)

        gen_context["kv_lens"] = kv_lens
        gen_context["ropes"] = ropes
        gen_context["past_key_values"] = past_key_values

        return gen_context

    @torch.no_grad()
    def gen_image(
        self,
        image_shape,
        gen_context,
        cfg_text_scale=4.0,
        cfg_img_scale=1.5,
        cfg_text_precontext=None,
        cfg_img_precontext=None,
        cfg_interval=(0.4, 1.0),
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        num_timesteps=50,
        timestep_shift=3.0,
        enable_taylorseer=False,
    ):
        """Generate one image from ``gen_context`` and return it as a PIL image.

        Args:
            image_shape: ``(H, W)`` of the image to generate.
            gen_context: Main (fully conditioned) generation context.
            cfg_text_precontext: Context used for the text-CFG branch.
                Required despite the ``None`` default.
            cfg_img_precontext: Context used for the image-CFG branch.
                Required despite the ``None`` default.
            cfg_text_scale, cfg_img_scale, cfg_interval, cfg_renorm_min,
            cfg_renorm_type, num_timesteps, timestep_shift, enable_taylorseer:
                Forwarded unchanged to ``model.generate_image``.

        Raises:
            TypeError: If either CFG pre-context is ``None``.
        """
        # Both contexts are subscripted unconditionally below; fail early with
        # a readable message instead of "'NoneType' object is not subscriptable".
        if cfg_text_precontext is None or cfg_img_precontext is None:
            raise TypeError(
                "gen_image requires both cfg_text_precontext and cfg_img_precontext "
                "(generation context dicts); got None"
            )

        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]
        generation_input = self.model.prepare_vae_latent(
            curr_kvlens=kv_lens,
            curr_rope=ropes,
            image_sizes=[image_shape],
            new_token_ids=self.new_token_ids,
        )

        # text cfg
        cfg_text_past_key_values = cfg_text_precontext["past_key_values"]
        kv_lens_cfg = cfg_text_precontext["kv_lens"]
        ropes_cfg = cfg_text_precontext["ropes"]
        generation_input_cfg_text = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg,
            image_sizes=[image_shape],
        )

        # img cfg
        cfg_img_past_key_values = cfg_img_precontext["past_key_values"]
        kv_lens_cfg = cfg_img_precontext["kv_lens"]
        ropes_cfg = cfg_img_precontext["ropes"]
        generation_input_cfg_img = self.model.prepare_vae_latent_cfg(
            curr_kvlens=kv_lens_cfg,
            curr_rope=ropes_cfg,
            image_sizes=[image_shape],
        )

        unpacked_latent = self.model.generate_image(
            past_key_values=past_key_values,
            cfg_text_past_key_values=cfg_text_past_key_values,
            cfg_img_past_key_values=cfg_img_past_key_values,
            num_timesteps=num_timesteps,
            cfg_text_scale=cfg_text_scale,
            cfg_img_scale=cfg_img_scale,
            cfg_interval=cfg_interval,
            cfg_renorm_min=cfg_renorm_min,
            cfg_renorm_type=cfg_renorm_type,
            timestep_shift=timestep_shift,
            **generation_input,
            cfg_text_packed_position_ids=generation_input_cfg_text["cfg_packed_position_ids"],
            cfg_text_packed_query_indexes=generation_input_cfg_text["cfg_packed_query_indexes"],
            cfg_text_key_values_lens=generation_input_cfg_text["cfg_key_values_lens"],
            cfg_text_packed_key_value_indexes=generation_input_cfg_text["cfg_packed_key_value_indexes"],
            cfg_img_packed_position_ids=generation_input_cfg_img["cfg_packed_position_ids"],
            cfg_img_packed_query_indexes=generation_input_cfg_img["cfg_packed_query_indexes"],
            cfg_img_key_values_lens=generation_input_cfg_img["cfg_key_values_lens"],
            cfg_img_packed_key_value_indexes=generation_input_cfg_img["cfg_packed_key_value_indexes"],
            enable_taylorseer=enable_taylorseer,
        )

        image = self.decode_image(unpacked_latent[0], image_shape)
        return image

    def decode_image(self, latent, image_shape):
        """Un-patchify ``latent``, decode it with the VAE, and return a PIL image.

        Args:
            latent: Packed latent for one image; it is reshaped to
                ``(1, h, w, p, p, c)`` with ``h = H // latent_downsample``,
                ``w = W // latent_downsample``, ``p = latent_patch_size`` and
                ``c = latent_channel`` (all read from ``self.model``).
            image_shape: ``(H, W)`` of the target image.
        """
        H, W = image_shape
        h, w = H // self.model.latent_downsample, W // self.model.latent_downsample

        latent = latent.reshape(
            1,
            h,
            w,
            self.model.latent_patch_size,
            self.model.latent_patch_size,
            self.model.latent_channel,
        )
        # (n, h, w, p, q, c) -> (n, c, h, p, w, q), then merge the patch axes
        # into the spatial axes to obtain an (n, c, h*p, w*p) latent map.
        latent = torch.einsum("nhwpqc->nchpwq", latent)
        latent = latent.reshape(
            1,
            self.model.latent_channel,
            h * self.model.latent_patch_size,
            w * self.model.latent_patch_size,
        )
        image = self.vae_model.decode(latent)
        # Rescale x -> x * 0.5 + 0.5, clamp to [0, 1], move channels last, scale to [0, 255].
        image = (image * 0.5 + 0.5).clamp(0, 1)[0].permute(1, 2, 0) * 255
        image = Image.fromarray((image).to(torch.uint8).cpu().numpy())

        return image

    @torch.no_grad()
    def gen_text(
        self,
        gen_context,
        max_length: int = 500,
        do_sample: bool = True,
        temperature: float = 1.0,
    ):
        """Generate text from a copy of ``gen_context`` and return it as a string.

        The context is deep-copied first, so the caller's ``gen_context`` is
        not modified by text generation. The decoded output is cut at the
        first ``"<|im_end|>"`` and only the part following the first
        ``"<|im_start|>"`` is returned (the decoded text is expected to
        contain that start marker).
        """
        gen_context = deepcopy(gen_context)
        past_key_values = gen_context["past_key_values"]
        kv_lens = gen_context["kv_lens"]
        ropes = gen_context["ropes"]

        generation_input = self.model.prepare_start_tokens(kv_lens, ropes, self.new_token_ids)
        generated_ids = self.model.generate_text(
            past_key_values=past_key_values,
            max_length=max_length,
            do_sample=do_sample,
            temperature=temperature,
            end_token_id=self.new_token_ids["eos_token_id"],
            **generation_input,
        )
        output = self.tokenizer.decode(generated_ids[:, 0])
        output = output.split("<|im_end|>")[0].split("<|im_start|>")[1]
        return output

    @torch.no_grad()
    def interleave_inference(
        self,
        input_lists: List[Union[str, Image.Image]],
        think=False,
        understanding_output=False,
        max_think_token_n=1000,
        do_sample=False,
        text_temperature=0.3,
        cfg_text_scale=3.0,
        cfg_img_scale=1.5,
        cfg_interval=(0.4, 1.0),
        timestep_shift=3.0,
        num_timesteps=50,
        cfg_renorm_min=0.0,
        cfg_renorm_type="global",
        image_shapes=(1024, 1024),
        enable_taylorseer=False,
    ) -> List[Union[str, Image.Image]]:
        """Run one round of interleaved inference over ``input_lists``.

        Three contexts are maintained while the inputs are consumed in order:

        * ``gen_context``: receives every input (and the system prompt when
          ``think`` is set).
        * ``cfg_text_context``: a snapshot of ``gen_context`` taken *before*
          each text input and *after* each image input.
        * ``cfg_img_context``: receives the system prompt and text inputs but
          never images.

        Args:
            input_lists: Sequence of ``str`` and ``PIL.Image.Image`` inputs.
            think: If true, prepend a "think" system prompt; for image
                generation also generate a text plan first and append it to
                the context.
            understanding_output: If true, only text is generated (and images
                are added to the context without the VAE path); otherwise an
                image is generated.
            max_think_token_n: ``max_length`` passed to :meth:`gen_text`.
            do_sample, text_temperature: Sampling options for :meth:`gen_text`.
            image_shapes: ``(H, W)`` of the generated image; overwritten by
                the size of the last (resized) input image, if any.
            cfg_text_scale, cfg_img_scale, cfg_interval, timestep_shift,
            num_timesteps, cfg_renorm_min, cfg_renorm_type, enable_taylorseer:
                Forwarded to :meth:`gen_image`.

        Returns:
            List with the generated text (if any) followed by the generated
            image (if any).

        Raises:
            ValueError: If an input is neither ``str`` nor ``PIL.Image.Image``.
        """
        output_list = []
        gen_context = self.init_gen_context()
        cfg_text_context = deepcopy(gen_context)
        cfg_img_context = deepcopy(gen_context)

        with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            if think:
                if understanding_output:
                    system_prompt = VLM_THINK_SYSTEM_PROMPT
                else:
                    system_prompt = GEN_THINK_SYSTEM_PROMPT
                gen_context = self.update_context_text(system_prompt, gen_context)
                cfg_img_context = self.update_context_text(system_prompt, cfg_img_context)

            for input_term in input_lists:
                if isinstance(input_term, str):
                    # Snapshot first: the text-CFG branch must not see this text.
                    cfg_text_context = deepcopy(gen_context)
                    gen_context = self.update_context_text(input_term, gen_context)
                    cfg_img_context = self.update_context_text(input_term, cfg_img_context)

                elif isinstance(input_term, Image.Image):
                    input_term = self.vae_transform.resize_transform(pil_img2rgb(input_term))
                    gen_context = self.update_context_image(
                        input_term, gen_context, vae=not understanding_output
                    )

                    # PIL size is (W, H); the generator expects (H, W).
                    image_shapes = input_term.size[::-1]
                    cfg_text_context = deepcopy(gen_context)

                else:
                    raise ValueError(f"Unsupported input type: {type(input_term)}")

            if understanding_output:
                gen_text = self.gen_text(
                    gen_context,
                    do_sample=do_sample,
                    temperature=text_temperature,
                    max_length=max_think_token_n,
                )
                output_list.append(gen_text)

            else:
                if think:
                    gen_text = self.gen_text(
                        gen_context,
                        do_sample=do_sample,
                        temperature=text_temperature,
                        max_length=max_think_token_n,
                    )
                    gen_context = self.update_context_text(gen_text, gen_context)
                    output_list.append(gen_text)

                img = self.gen_image(
                    image_shapes,
                    gen_context,
                    cfg_text_precontext=cfg_text_context,
                    cfg_img_precontext=cfg_img_context,
                    cfg_text_scale=cfg_text_scale,
                    cfg_img_scale=cfg_img_scale,
                    cfg_interval=cfg_interval,
                    timestep_shift=timestep_shift,
                    num_timesteps=num_timesteps,
                    cfg_renorm_min=cfg_renorm_min,
                    cfg_renorm_type=cfg_renorm_type,
                    enable_taylorseer=enable_taylorseer,
                )

                output_list.append(img)

        return output_list

    def __call__(
        self,
        image: Optional[Image.Image] = None,
        text: Optional[str] = None,
        **kargs,
    ) -> Dict[str, Any]:
        """Convenience wrapper around :meth:`interleave_inference`.

        Args:
            image: Optional input image (placed before the text).
            text: Optional input text.
            **kargs: Forwarded to :meth:`interleave_inference`.

        Returns:
            ``{"image": ..., "text": ...}`` holding the last generated image
            and the last generated text (``None`` where nothing was produced).
            If neither input is given, a message is printed and both entries
            are ``None``.
        """
        output_dict = {"image": None, "text": None}

        if image is None and text is None:
            print("Please provide at least one input: either an image or text.")
            return output_dict

        input_list = []
        if image is not None:
            input_list.append(image)
        if text is not None:
            input_list.append(text)

        output_list = self.interleave_inference(input_list, **kargs)

        for i in output_list:
            if isinstance(i, Image.Image):
                output_dict["image"] = i
            elif isinstance(i, str):
                output_dict["text"] = i
        return output_dict
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
.
import
autoencoder
,
bagel
,
qwen2
,
siglip
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/autoencoder.py
0 → 100644
View file @
876a36a4
This diff is collapsed.
Click to expand it.
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
.bagel
import
Bagel
,
BagelConfig
from
.qwen2_navit
import
Qwen2Config
,
Qwen2ForCausalLM
,
Qwen2Model
from
.siglip_navit
import
SiglipVisionConfig
,
SiglipVisionModel
__all__
=
[
"BagelConfig"
,
"Bagel"
,
"Qwen2Config"
,
"Qwen2Model"
,
"Qwen2ForCausalLM"
,
"SiglipVisionConfig"
,
"SiglipVisionModel"
,
]
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/bagel.py
0 → 100644
View file @
876a36a4
This diff is collapsed.
Click to expand it.
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/modeling_utils.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2022 Facebook, Inc. and its affiliates.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: CC BY-NC 4.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under CC BY-NC 4.0, with the full license text
# available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
#
# This modified file is released under the same license.
import
math
import
numpy
as
np
import
torch
from
torch
import
nn
from
transformers.activations
import
ACT2FN
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
    """Build a 2D sine-cosine position embedding for a square grid.

    Args:
        embed_dim: Embedding dimension, forwarded to
            ``get_2d_sincos_pos_embed_from_grid``.
        grid_size: Number of positions along each side of the grid.
        cls_token: Together with ``extra_tokens > 0``, enables zero rows
            prepended to the embedding.
        extra_tokens: Number of all-zero rows to prepend when ``cls_token``
            is truthy.

    Returns:
        The array produced by ``get_2d_sincos_pos_embed_from_grid``, with
        ``extra_tokens`` zero rows prepended along axis 0 when requested.
    """
    coords_h = np.arange(grid_size, dtype=np.float32)
    coords_w = np.arange(grid_size, dtype=np.float32)
    # The width coordinates go first, so grid[0] varies along the width axis.
    mesh = np.meshgrid(coords_w, coords_h)
    grid = np.stack(mesh, axis=0).reshape([2, 1, grid_size, grid_size])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not (cls_token and extra_tokens > 0):
        return pos_embed

    padding = np.zeros([extra_tokens, embed_dim])
    return np.concatenate([padding, pos_embed], axis=0)
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Encode a 2-plane coordinate grid with 1D sine-cosine embeddings.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes
    ``grid[1]``; the two halves are concatenated along the feature axis.

    Args:
        embed_dim: Total embedding dimension; must be even.
        grid: Indexable pair of coordinate arrays (``grid[0]``, ``grid[1]``).

    Returns:
        Array of shape ``(M, embed_dim)`` where ``M`` is the number of
        positions in each coordinate array.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    first_half = get_1d_sincos_pos_embed_from_grid(half_dim, grid[0])  # (M, D/2)
    second_half = get_1d_sincos_pos_embed_from_grid(half_dim, grid[1])  # (M, D/2)

    return np.concatenate([first_half, second_half], axis=1)  # (M, D)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sine-cosine embedding of a set of scalar positions.

    Args:
        embed_dim: Output dimension per position; must be even.
        pos: NumPy array of positions; it is flattened to shape ``(M,)``.

    Returns:
        Array of shape ``(M, embed_dim)``: the first ``embed_dim // 2``
        columns are sines, the remaining columns are cosines.
    """
    assert embed_dim % 2 == 0

    # Frequencies 1 / 10000^(k / (D/2)) for k = 0 .. D/2 - 1.
    exponents = np.arange(embed_dim // 2, dtype=np.float64)
    exponents /= embed_dim / 2.0
    omega = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    # Outer product of positions and frequencies.
    angles = np.einsum("m,d->md", flat_pos, omega)  # (M, D/2)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
# --------------------------------------------------------
# TimestepEmbedder
# Reference:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps into ``hidden_size``-dimensional vectors.

    A fixed sinusoidal encoding of the timestep is fed through a small MLP
    (Linear -> SiLU -> Linear).
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings (cosines first,
                 then sines, plus one zero column when ``dim`` is odd).
        """
        half = dim // 2
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)

        angles = t[:, None].float() * freqs[None]
        pieces = [torch.cos(angles), torch.sin(angles)]
        if dim % 2:
            # Odd output size: pad with a single zero column.
            pieces.append(torch.zeros_like(angles[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Return the MLP output for the sinusoidal encoding of ``t``."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
class MLPconnector(nn.Module):
    """Two-layer MLP: ``fc1`` -> activation -> ``fc2``.

    Args:
        in_dim: Input feature size of ``fc1``.
        out_dim: Output feature size of both ``fc1`` and ``fc2``.
        hidden_act: Key looked up in ``ACT2FN`` to obtain the activation.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
        super().__init__()
        self.activation_fn = ACT2FN[hidden_act]
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.fc2 = nn.Linear(out_dim, out_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc2(activation_fn(fc1(hidden_states)))``."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class PositionEmbedding(nn.Module):
    """Frozen lookup table of 2D sine-cosine position embeddings.

    The table has ``max_num_patch_per_side ** 2`` rows of ``hidden_size``
    features and is stored as a parameter with ``requires_grad=False``.
    """

    def __init__(self, max_num_patch_per_side, hidden_size):
        super().__init__()
        self.max_num_patch_per_side = max_num_patch_per_side
        self.hidden_size = hidden_size

        table = torch.zeros(max_num_patch_per_side**2, hidden_size)
        self.pos_embed = nn.Parameter(table, requires_grad=False)
        self._init_weights()

    def _init_weights(self):
        """Fill the (frozen) table with the sin-cos embedding."""
        sincos = get_2d_sincos_pos_embed(self.hidden_size, self.max_num_patch_per_side)
        sincos = torch.from_numpy(sincos).float()
        self.pos_embed.data.copy_(sincos)

    def forward(self, position_ids):
        """Return the table rows selected by ``position_ids``."""
        return self.pos_embed[position_ids]
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/qwen2_navit.py
0 → 100644
View file @
876a36a4
This diff is collapsed.
Click to expand it.
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/bagel/siglip_navit.py
0 → 100644
View file @
876a36a4
This diff is collapsed.
Click to expand it.
SenseNova-SI-main/sensenova_si/bagel_utils/modeling/cache_utils/taylorseer.py
0 → 100644
View file @
876a36a4
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment