Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ModelZoo
SenseNova-SI
Commits
876a36a4
"vscode:/vscode.git/clone" did not exist on "47684368dbbe4185d068be77d32a962059cfc37c"
Commit
876a36a4
authored
May 27, 2026
by
raojy
Browse files
first
parent
eda2afb8
Changes
175
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5670 additions
and
0 deletions
+5670
-0
SenseNova-SI-main/training/bagel/data/data_utils.py
SenseNova-SI-main/training/bagel/data/data_utils.py
+202
-0
SenseNova-SI-main/training/bagel/data/dataset_base.py
SenseNova-SI-main/training/bagel/data/dataset_base.py
+732
-0
SenseNova-SI-main/training/bagel/data/dataset_info.py
SenseNova-SI-main/training/bagel/data/dataset_info.py
+48
-0
SenseNova-SI-main/training/bagel/data/dataset_info/sensenova_si_800K.json
...n/training/bagel/data/dataset_info/sensenova_si_800K.json
+7
-0
SenseNova-SI-main/training/bagel/data/dataset_info/sensenova_si_8M.json
...ain/training/bagel/data/dataset_info/sensenova_si_8M.json
+7
-0
SenseNova-SI-main/training/bagel/data/distributed_iterable_dataset.py
...-main/training/bagel/data/distributed_iterable_dataset.py
+59
-0
SenseNova-SI-main/training/bagel/data/edit_dataset_jsonl.py
SenseNova-SI-main/training/bagel/data/edit_dataset_jsonl.py
+243
-0
SenseNova-SI-main/training/bagel/data/parquet_utils.py
SenseNova-SI-main/training/bagel/data/parquet_utils.py
+95
-0
SenseNova-SI-main/training/bagel/data/t2i_dataset.py
SenseNova-SI-main/training/bagel/data/t2i_dataset.py
+158
-0
SenseNova-SI-main/training/bagel/data/t2i_dataset_jsonl.py
SenseNova-SI-main/training/bagel/data/t2i_dataset_jsonl.py
+158
-0
SenseNova-SI-main/training/bagel/data/transforms.py
SenseNova-SI-main/training/bagel/data/transforms.py
+306
-0
SenseNova-SI-main/training/bagel/data/video_utils.py
SenseNova-SI-main/training/bagel/data/video_utils.py
+179
-0
SenseNova-SI-main/training/bagel/data/vlm_dataset.py
SenseNova-SI-main/training/bagel/data/vlm_dataset.py
+231
-0
SenseNova-SI-main/training/bagel/environment.yml
SenseNova-SI-main/training/bagel/environment.yml
+98
-0
SenseNova-SI-main/training/bagel/modeling/__init__.py
SenseNova-SI-main/training/bagel/modeling/__init__.py
+4
-0
SenseNova-SI-main/training/bagel/modeling/autoencoder.py
SenseNova-SI-main/training/bagel/modeling/autoencoder.py
+386
-0
SenseNova-SI-main/training/bagel/modeling/bagel/__init__.py
SenseNova-SI-main/training/bagel/modeling/bagel/__init__.py
+17
-0
SenseNova-SI-main/training/bagel/modeling/bagel/bagel.py
SenseNova-SI-main/training/bagel/modeling/bagel/bagel.py
+1175
-0
SenseNova-SI-main/training/bagel/modeling/bagel/modeling_utils.py
...a-SI-main/training/bagel/modeling/bagel/modeling_utils.py
+153
-0
SenseNova-SI-main/training/bagel/modeling/bagel/qwen2_navit.py
...Nova-SI-main/training/bagel/modeling/bagel/qwen2_navit.py
+1412
-0
No files found.
SenseNova-SI-main/training/bagel/data/data_utils.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
math
import
random
import
numpy
as
np
import
torch
from
PIL
import
Image
from
torch.nn.attention.flex_attention
import
and_masks
,
or_masks
def
create_sparse_mask
(
document_lens
,
split_lens
,
attn_modes
,
device
):
def
causal_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
q_idx
>=
kv_idx
def
full_and_noise_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
(
full_and_noise_seq_id
[
q_idx
]
==
full_and_noise_seq_id
[
kv_idx
])
&
(
full_and_noise_seq_id
[
q_idx
]
>=
0
)
def
remove_noise_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
~
(
(
noise_seq_id
[
kv_idx
]
>=
0
)
&
(
noise_seq_id
[
q_idx
]
!=
noise_seq_id
[
kv_idx
])
)
def
sample_mask
(
b
,
h
,
q_idx
,
kv_idx
):
return
document_id
[
q_idx
]
==
document_id
[
kv_idx
]
full_and_noise_tmp
=
[]
noise_tmp
=
[]
for
i
,
(
length
,
model
)
in
enumerate
(
zip
(
split_lens
,
attn_modes
)):
value
=
i
if
model
in
[
"full"
,
"noise"
]
else
-
1
full_and_noise_tmp
.
extend
([
value
]
*
length
)
value_noise
=
i
if
model
==
"noise"
else
-
1
noise_tmp
.
extend
([
value_noise
]
*
length
)
full_and_noise_seq_id
=
torch
.
Tensor
(
full_and_noise_tmp
).
to
(
device
)
noise_seq_id
=
torch
.
Tensor
(
noise_tmp
).
to
(
device
)
document_id
=
torch
.
cat
(
[
torch
.
full
((
l
,),
i
)
for
i
,
l
in
enumerate
(
document_lens
,
start
=
1
)]
).
to
(
device
)
return
and_masks
(
or_masks
(
causal_mask
,
full_and_noise_mask
),
remove_noise_mask
,
sample_mask
)
def
patchify
(
image
,
patch_size
):
p
=
patch_size
c
,
h
,
w
=
image
.
shape
assert
h
%
p
==
0
and
w
%
p
==
0
image
=
image
.
reshape
(
c
,
h
//
p
,
p
,
w
//
p
,
p
)
image
=
torch
.
einsum
(
"chpwq->hwpqc"
,
image
)
image
=
image
.
reshape
(
-
1
,
p
**
2
*
c
)
return
image
def
get_flattened_position_ids_extrapolate
(
img_h
,
img_w
,
patch_size
,
max_num_patches_per_side
):
num_patches_h
,
num_patches_w
=
img_h
//
patch_size
,
img_w
//
patch_size
coords_h
=
torch
.
arange
(
0
,
num_patches_h
)
coords_w
=
torch
.
arange
(
0
,
num_patches_w
)
pos_ids
=
(
coords_h
[:,
None
]
*
max_num_patches_per_side
+
coords_w
).
flatten
()
return
pos_ids
def
get_flattened_position_ids_interpolate
(
img_h
,
img_w
,
patch_size
,
max_num_patches_per_side
):
num_patches_h
,
num_patches_w
=
img_h
//
patch_size
,
img_w
//
patch_size
boundaries
=
torch
.
arange
(
1
/
max_num_patches_per_side
,
1.0
,
1
/
max_num_patches_per_side
)
fractional_coords_h
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
num_patches_h
)
fractional_coords_w
=
torch
.
arange
(
0
,
1
-
1e-6
,
1
/
num_patches_w
)
bucket_coords_h
=
torch
.
bucketize
(
fractional_coords_h
,
boundaries
,
right
=
True
)
bucket_coords_w
=
torch
.
bucketize
(
fractional_coords_w
,
boundaries
,
right
=
True
)
pos_ids
=
(
bucket_coords_h
[:,
None
]
*
max_num_patches_per_side
+
bucket_coords_w
).
flatten
()
return
pos_ids
def
prepare_attention_mask_per_sample
(
split_lens
,
attn_modes
,
device
=
"cpu"
):
"""
nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
a sample, where each sample contains multiple splits with different attn modes.
nested_attn_modes: whether to use full attn in each split.
"""
sample_len
=
sum
(
split_lens
)
attention_mask
=
torch
.
zeros
(
(
sample_len
,
sample_len
),
dtype
=
torch
.
bool
,
device
=
device
)
csum
=
0
for
s
,
attn_mode
in
zip
(
split_lens
,
attn_modes
):
assert
attn_mode
in
[
"causal"
,
"full"
,
"noise"
]
if
attn_mode
==
"causal"
:
attention_mask
[
csum
:
csum
+
s
,
csum
:
csum
+
s
]
=
torch
.
ones
(
(
s
,
s
),
device
=
device
).
tril
()
attention_mask
[
csum
:
csum
+
s
,
:
csum
]
=
1
else
:
attention_mask
[
csum
:
csum
+
s
,
csum
:
csum
+
s
]
=
torch
.
ones
((
s
,
s
))
attention_mask
[
csum
:
csum
+
s
,
:
csum
]
=
1
csum
+=
s
csum
=
0
for
s
,
attn_mode
in
zip
(
split_lens
,
attn_modes
):
if
attn_mode
==
"noise"
:
attention_mask
[:,
csum
:
csum
+
s
]
=
torch
.
zeros
((
sample_len
,
s
))
attention_mask
[
csum
:
csum
+
s
,
csum
:
csum
+
s
]
=
torch
.
ones
((
s
,
s
))
csum
+=
s
attention_mask
=
torch
.
zeros_like
(
attention_mask
,
dtype
=
torch
.
float
).
masked_fill_
(
~
attention_mask
,
float
(
"-inf"
)
)
return
attention_mask
def
split_integer_exp_decay
(
S
,
ng_sample_decay
=
1.0
):
if
ng_sample_decay
==
1.0
:
N
=
random
.
randint
(
1
,
S
)
else
:
base
=
(
1
-
ng_sample_decay
)
/
(
1
-
math
.
pow
(
ng_sample_decay
,
S
))
p
=
[
base
*
math
.
pow
(
ng_sample_decay
,
i
)
for
i
in
range
(
S
)]
N
=
random
.
choices
(
list
(
range
(
1
,
S
+
1
)),
p
,
k
=
1
)[
0
]
cumsum
=
[
0
]
+
sorted
(
random
.
sample
(
range
(
1
,
S
),
N
-
1
))
+
[
S
]
result
=
[
cumsum
[
i
+
1
]
-
cumsum
[
i
]
for
i
in
range
(
len
(
cumsum
)
-
1
)]
return
result
,
cumsum
def
pil_img2rgb
(
image
):
if
image
.
mode
==
"RGBA"
or
image
.
info
.
get
(
"transparency"
,
None
)
is
not
None
:
image
=
image
.
convert
(
"RGBA"
)
white
=
Image
.
new
(
mode
=
"RGB"
,
size
=
image
.
size
,
color
=
(
255
,
255
,
255
))
white
.
paste
(
image
,
mask
=
image
.
split
()[
3
])
image
=
white
else
:
image
=
image
.
convert
(
"RGB"
)
return
image
def
add_special_tokens
(
tokenizer
):
all_special_tokens
=
[]
for
k
,
v
in
tokenizer
.
special_tokens_map
.
items
():
if
isinstance
(
v
,
str
):
all_special_tokens
.
append
(
v
)
elif
isinstance
(
v
,
list
):
all_special_tokens
+=
v
new_tokens
=
[]
if
"<|im_start|>"
not
in
all_special_tokens
:
new_tokens
.
append
(
"<|im_start|>"
)
if
"<|im_end|>"
not
in
all_special_tokens
:
new_tokens
.
append
(
"<|im_end|>"
)
if
"<|vision_start|>"
not
in
all_special_tokens
:
new_tokens
.
append
(
"<|vision_start|>"
)
if
"<|vision_end|>"
not
in
all_special_tokens
:
new_tokens
.
append
(
"<|vision_end|>"
)
num_new_tokens
=
tokenizer
.
add_tokens
(
new_tokens
)
bos_token_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|im_start|>"
)
eos_token_id
=
tokenizer
.
convert_tokens_to_ids
(
"<|im_end|>"
)
start_of_image
=
tokenizer
.
convert_tokens_to_ids
(
"<|vision_start|>"
)
end_of_image
=
tokenizer
.
convert_tokens_to_ids
(
"<|vision_end|>"
)
new_token_ids
=
dict
(
bos_token_id
=
bos_token_id
,
eos_token_id
=
eos_token_id
,
start_of_image
=
start_of_image
,
end_of_image
=
end_of_image
,
)
return
tokenizer
,
new_token_ids
,
num_new_tokens
def
len2weight
(
x
,
loss_reduction
=
"square"
):
if
x
==
0
:
return
x
if
loss_reduction
==
"token"
:
return
1
if
loss_reduction
==
"sample"
:
return
1
/
x
if
loss_reduction
==
"square"
:
return
1
/
(
x
**
0.5
)
raise
NotImplementedError
(
loss_reduction
)
def
load_image
(
image_path
):
return
Image
.
open
(
image_path
)
SenseNova-SI-main/training/bagel/data/dataset_base.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
json
import
random
import
numpy
as
np
import
torch
from
.data_utils
import
(
get_flattened_position_ids_extrapolate
,
get_flattened_position_ids_interpolate
,
len2weight
,
patchify
,
prepare_attention_mask_per_sample
,
)
from
.dataset_info
import
DATASET_INFO
,
DATASET_REGISTRY
from
.transforms
import
ImageTransform
from
.video_utils
import
FrameSampler
class
DataConfig
:
def
__init__
(
self
,
grouped_datasets
,
text_cond_dropout_prob
=
0.1
,
vit_cond_dropout_prob
=
0.4
,
vae_cond_dropout_prob
=
0.1
,
vae_image_downsample
=
16
,
max_latent_size
=
32
,
vit_patch_size
=
14
,
max_num_patch_per_side
=
70
,
):
self
.
grouped_datasets
=
grouped_datasets
self
.
text_cond_dropout_prob
=
text_cond_dropout_prob
self
.
vit_cond_dropout_prob
=
vit_cond_dropout_prob
self
.
vit_patch_size
=
vit_patch_size
self
.
max_num_patch_per_side
=
max_num_patch_per_side
self
.
vae_cond_dropout_prob
=
vae_cond_dropout_prob
self
.
vae_image_downsample
=
vae_image_downsample
self
.
max_latent_size
=
max_latent_size
class
PackedDataset
(
torch
.
utils
.
data
.
IterableDataset
):
bos_token_id
:
int
eos_token_id
:
int
start_of_image
:
int
end_of_image
:
int
def
__init__
(
self
,
data_config
,
tokenizer
,
special_tokens
,
local_rank
,
world_size
,
num_workers
,
expected_num_tokens
=
32768
,
max_num_tokens_per_sample
=
16384
,
max_num_tokens
=
36864
,
prefer_buffer_before
=
16384
,
max_buffer_size
=
50
,
interpolate_pos
=
False
,
use_flex
=
False
,
data_status
=
None
,
):
super
().
__init__
()
self
.
expected_num_tokens
=
expected_num_tokens
self
.
max_num_tokens_per_sample
=
max_num_tokens_per_sample
self
.
prefer_buffer_before
=
prefer_buffer_before
self
.
max_num_tokens
=
max_num_tokens
self
.
max_buffer_size
=
max_buffer_size
self
.
tokenizer
=
tokenizer
self
.
local_rank
=
local_rank
self
.
world_size
=
world_size
self
.
num_workers
=
num_workers
self
.
use_flex
=
use_flex
for
k
,
v
in
special_tokens
.
items
():
setattr
(
self
,
k
,
v
)
grouped_datasets
,
is_mandatory
,
grouped_weights
=
self
.
build_datasets
(
data_config
.
grouped_datasets
,
data_status
)
self
.
grouped_datasets
=
grouped_datasets
self
.
dataset_iters
=
[
iter
(
dataset
)
for
dataset
in
grouped_datasets
]
self
.
is_mandatory
=
is_mandatory
self
.
grouped_weights
=
grouped_weights
self
.
data_config
=
data_config
self
.
interpolate_pos
=
interpolate_pos
if
self
.
interpolate_pos
:
self
.
get_flattened_position_ids
=
get_flattened_position_ids_interpolate
else
:
self
.
get_flattened_position_ids
=
get_flattened_position_ids_extrapolate
def
build_datasets
(
self
,
datasets_metainfo
,
data_status
):
datasets
=
[]
is_mandatory
=
[]
grouped_weights
=
[]
for
grouped_dataset_name
,
dataset_args
in
datasets_metainfo
.
items
():
is_mandatory
.
append
(
dataset_args
.
pop
(
"is_mandatory"
,
False
))
grouped_weights
.
append
(
dataset_args
.
pop
(
"weight"
,
0.0
))
if
"frame_sampler_args"
in
dataset_args
.
keys
():
frame_sampler
=
FrameSampler
(
**
dataset_args
.
pop
(
"frame_sampler_args"
))
dataset_args
[
"frame_sampler"
]
=
frame_sampler
if
"image_transform_args"
in
dataset_args
.
keys
():
transform
=
ImageTransform
(
**
dataset_args
.
pop
(
"image_transform_args"
))
dataset_args
[
"transform"
]
=
transform
if
"vit_image_transform_args"
in
dataset_args
.
keys
():
vit_transform
=
ImageTransform
(
**
dataset_args
.
pop
(
"vit_image_transform_args"
)
)
dataset_args
[
"vit_transform"
]
=
vit_transform
if
"dataset_names"
in
dataset_args
.
keys
():
dataset_names
=
dataset_args
.
pop
(
"dataset_names"
)
else
:
dataset_names
=
DATASET_INFO
[
grouped_dataset_name
].
keys
()
if
"num_used_data"
not
in
dataset_args
.
keys
():
dataset_args
[
"num_used_data"
]
=
[]
append_num_used_data
=
True
else
:
append_num_used_data
=
False
dataset_args
[
"data_dir_list"
]
=
[]
for
item
in
dataset_names
:
if
self
.
local_rank
==
0
:
print
(
f
"Preparing Dataset
{
grouped_dataset_name
}
/
{
item
}
"
)
meta_info
=
DATASET_INFO
[
grouped_dataset_name
][
item
]
dataset_args
[
"data_dir_list"
].
append
(
meta_info
[
"data_dir"
])
if
append_num_used_data
:
dataset_args
[
"num_used_data"
].
append
(
meta_info
[
"num_total_samples"
])
if
"parquet_info_path"
in
meta_info
.
keys
():
if
"parquet_info"
not
in
dataset_args
.
keys
():
dataset_args
[
"parquet_info"
]
=
{}
with
open
(
meta_info
[
"parquet_info_path"
],
"r"
)
as
f
:
parquet_info
=
json
.
load
(
f
)
dataset_args
[
"parquet_info"
].
update
(
parquet_info
)
if
"json_dir"
in
meta_info
.
keys
():
# parquet/tar with json
if
"json_dir_list"
not
in
dataset_args
.
keys
():
dataset_args
[
"json_dir_list"
]
=
[
meta_info
[
"json_dir"
]]
else
:
dataset_args
[
"json_dir_list"
].
append
(
meta_info
[
"json_dir"
])
if
"jsonl_path"
in
meta_info
.
keys
():
# jsonl with jpeg
if
"jsonl_path_list"
not
in
dataset_args
.
keys
():
dataset_args
[
"jsonl_path_list"
]
=
[
meta_info
[
"jsonl_path"
]]
else
:
dataset_args
[
"jsonl_path_list"
].
append
(
meta_info
[
"jsonl_path"
])
resume_data_status
=
dataset_args
.
pop
(
"resume_data_status"
,
True
)
if
(
data_status
is
not
None
and
grouped_dataset_name
in
data_status
.
keys
()
and
resume_data_status
):
data_status_per_group
=
data_status
[
grouped_dataset_name
]
else
:
data_status_per_group
=
None
dataset
=
DATASET_REGISTRY
[
grouped_dataset_name
](
dataset_name
=
grouped_dataset_name
,
tokenizer
=
self
.
tokenizer
,
local_rank
=
self
.
local_rank
,
world_size
=
self
.
world_size
,
num_workers
=
self
.
num_workers
,
data_status
=
data_status_per_group
,
**
dataset_args
,
)
datasets
.
append
(
dataset
)
return
datasets
,
is_mandatory
,
grouped_weights
def
set_epoch
(
self
,
seed
):
for
dataset
in
self
.
grouped_datasets
:
dataset
.
set_epoch
(
seed
)
def
set_sequence_status
(
self
):
sequence_status
=
dict
(
curr
=
0
,
sample_lens
=
list
(),
packed_position_ids
=
list
(),
nested_attention_masks
=
list
(),
split_lens
=
list
(),
attn_modes
=
list
(),
packed_text_ids
=
list
(),
packed_text_indexes
=
list
(),
packed_label_ids
=
list
(),
ce_loss_indexes
=
list
(),
ce_loss_weights
=
list
(),
vae_image_tensors
=
list
(),
packed_latent_position_ids
=
list
(),
vae_latent_shapes
=
list
(),
packed_vae_token_indexes
=
list
(),
packed_timesteps
=
list
(),
mse_loss_indexes
=
list
(),
packed_vit_tokens
=
list
(),
vit_token_seqlens
=
list
(),
packed_vit_position_ids
=
list
(),
packed_vit_token_indexes
=
list
(),
)
return
sequence_status
def
to_tensor
(
self
,
sequence_status
):
data
=
dict
(
sequence_length
=
sum
(
sequence_status
[
"sample_lens"
]),
sample_lens
=
sequence_status
[
"sample_lens"
],
packed_text_ids
=
torch
.
tensor
(
sequence_status
[
"packed_text_ids"
]),
packed_text_indexes
=
torch
.
tensor
(
sequence_status
[
"packed_text_indexes"
]),
packed_position_ids
=
torch
.
tensor
(
sequence_status
[
"packed_position_ids"
]),
)
if
not
self
.
use_flex
:
data
[
"nested_attention_masks"
]
=
sequence_status
[
"nested_attention_masks"
]
else
:
sequence_len
=
data
[
"sequence_length"
]
pad_len
=
self
.
max_num_tokens
-
sequence_len
data
[
"split_lens"
]
=
sequence_status
[
"split_lens"
]
+
[
pad_len
]
data
[
"attn_modes"
]
=
sequence_status
[
"attn_modes"
]
+
[
"causal"
]
data
[
"sample_lens"
]
+=
[
pad_len
]
# if the model has a convnet vae (e.g., as visual tokenizer)
if
len
(
sequence_status
[
"vae_image_tensors"
])
>
0
:
image_tensors
=
sequence_status
.
pop
(
"vae_image_tensors"
)
image_sizes
=
[
item
.
shape
for
item
in
image_tensors
]
max_image_size
=
[
max
(
item
)
for
item
in
list
(
zip
(
*
image_sizes
))]
padded_images
=
torch
.
zeros
(
size
=
(
len
(
image_tensors
),
*
max_image_size
))
for
i
,
image_tensor
in
enumerate
(
image_tensors
):
padded_images
[
i
,
:,
:
image_tensor
.
shape
[
1
],
:
image_tensor
.
shape
[
2
]
]
=
image_tensor
data
[
"padded_images"
]
=
padded_images
data
[
"patchified_vae_latent_shapes"
]
=
sequence_status
[
"vae_latent_shapes"
]
data
[
"packed_latent_position_ids"
]
=
torch
.
cat
(
sequence_status
[
"packed_latent_position_ids"
],
dim
=
0
)
data
[
"packed_vae_token_indexes"
]
=
torch
.
tensor
(
sequence_status
[
"packed_vae_token_indexes"
]
)
# if the model has a vit (e.g., as visual tokenizer)
if
len
(
sequence_status
[
"packed_vit_tokens"
])
>
0
:
data
[
"packed_vit_tokens"
]
=
torch
.
cat
(
sequence_status
[
"packed_vit_tokens"
],
dim
=
0
)
data
[
"packed_vit_position_ids"
]
=
torch
.
cat
(
sequence_status
[
"packed_vit_position_ids"
],
dim
=
0
)
data
[
"packed_vit_token_indexes"
]
=
torch
.
tensor
(
sequence_status
[
"packed_vit_token_indexes"
]
)
data
[
"vit_token_seqlens"
]
=
torch
.
tensor
(
sequence_status
[
"vit_token_seqlens"
]
)
# if the model is required to perform visual generation
if
len
(
sequence_status
[
"packed_timesteps"
])
>
0
:
data
[
"packed_timesteps"
]
=
torch
.
tensor
(
sequence_status
[
"packed_timesteps"
])
data
[
"mse_loss_indexes"
]
=
torch
.
tensor
(
sequence_status
[
"mse_loss_indexes"
])
# if the model is required to perform text generation
if
len
(
sequence_status
[
"packed_label_ids"
])
>
0
:
data
[
"packed_label_ids"
]
=
torch
.
tensor
(
sequence_status
[
"packed_label_ids"
])
data
[
"ce_loss_indexes"
]
=
torch
.
tensor
(
sequence_status
[
"ce_loss_indexes"
])
data
[
"ce_loss_weights"
]
=
torch
.
tensor
(
sequence_status
[
"ce_loss_weights"
])
return
data
def
__iter__
(
self
):
total_weights
=
sum
(
self
.
grouped_weights
)
assert
total_weights
>
0.0
group_cumprobs
=
[
sum
(
self
.
grouped_weights
[:
i
+
1
])
/
total_weights
for
i
in
range
(
len
(
self
.
grouped_weights
))
]
sequence_status
=
self
.
set_sequence_status
()
batch_data_indexes
=
[]
buffer
=
[]
video_buffer
=
[]
# Separate buffer for extremely long video samples
while
True
:
# Ensure at least one sample from each group
if
sequence_status
[
"curr"
]
==
0
:
if
len
(
video_buffer
)
>
0
:
sample
=
video_buffer
.
pop
(
0
)
num_tokens
=
sample
[
"num_tokens"
]
+
2
*
len
(
sample
[
"sequence_plan"
])
sequence_status
=
self
.
pack_sequence
(
sample
,
sequence_status
)
batch_data_indexes
.
append
(
sample
[
"data_indexes"
])
else
:
for
group_index
,
group_iter
in
enumerate
(
self
.
dataset_iters
):
if
self
.
is_mandatory
[
group_index
]:
while
True
:
sample
=
next
(
group_iter
)
# if a sample is too long, put it in video buffer
num_tokens
=
sample
[
"num_tokens"
]
+
2
*
len
(
sample
[
"sequence_plan"
]
)
if
num_tokens
>
self
.
max_num_tokens_per_sample
:
if
len
(
video_buffer
)
<
self
.
max_buffer_size
:
video_buffer
.
append
(
sample
)
print
(
f
"Added sample with length
{
num_tokens
}
to video_buffer (size:
{
len
(
video_buffer
)
}
)"
)
else
:
print
(
f
"video_buffer full, skip a sample with length
{
num_tokens
}
"
)
break
elif
num_tokens
<
self
.
max_num_tokens_per_sample
:
sequence_status
=
self
.
pack_sequence
(
sample
,
sequence_status
)
batch_data_indexes
.
append
(
sample
[
"data_indexes"
])
break
if
sequence_status
[
"curr"
]
>=
self
.
expected_num_tokens
:
data
=
self
.
to_tensor
(
sequence_status
)
data
[
"batch_data_indexes"
]
=
batch_data_indexes
print
(
f
"Yielding
{
len
(
sequence_status
[
'sample_lens'
])
}
3D data with length
{
sum
(
sequence_status
[
'sample_lens'
])
}
, length of each sample:
{
sequence_status
[
'sample_lens'
]
}
"
)
yield
data
sequence_status
=
self
.
set_sequence_status
()
batch_data_indexes
=
[]
if
sequence_status
[
"curr"
]
<
self
.
prefer_buffer_before
and
len
(
buffer
)
>
0
:
sample
=
buffer
.
pop
(
0
)
sample_from_buffer
=
True
else
:
# sample normally across all groups
n
=
random
.
random
()
group_index
=
0
for
i
,
cumprob
in
enumerate
(
group_cumprobs
):
if
n
<
cumprob
:
group_index
=
i
break
sample
=
next
(
self
.
dataset_iters
[
group_index
])
sample_from_buffer
=
False
# if a sample is too long, skip it
num_tokens
=
sample
[
"num_tokens"
]
+
2
*
len
(
sample
[
"sequence_plan"
])
if
num_tokens
>
self
.
max_num_tokens_per_sample
:
if
len
(
video_buffer
)
<
self
.
max_buffer_size
:
video_buffer
.
append
(
sample
)
print
(
f
"Added sample with length
{
num_tokens
}
to video_buffer (size:
{
len
(
video_buffer
)
}
)"
)
else
:
print
(
f
"video_buffer full, skip a sample with length
{
num_tokens
}
"
)
continue
if
sequence_status
[
"curr"
]
+
num_tokens
>
self
.
max_num_tokens
:
if
len
(
buffer
)
<
self
.
max_buffer_size
and
not
sample_from_buffer
:
buffer
.
append
(
sample
)
else
:
# print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
data
=
self
.
to_tensor
(
sequence_status
)
data
[
"batch_data_indexes"
]
=
batch_data_indexes
yield
data
sequence_status
=
self
.
set_sequence_status
()
batch_data_indexes
=
[]
continue
sequence_status
=
self
.
pack_sequence
(
sample
,
sequence_status
)
batch_data_indexes
.
append
(
sample
[
"data_indexes"
])
if
sequence_status
[
"curr"
]
>=
self
.
expected_num_tokens
:
data
=
self
.
to_tensor
(
sequence_status
)
data
[
"batch_data_indexes"
]
=
batch_data_indexes
yield
data
sequence_status
=
self
.
set_sequence_status
()
batch_data_indexes
=
[]
def
pack_sequence
(
self
,
sample
,
sequence_status
):
image_tensor_list
=
sample
[
"image_tensor_list"
]
text_ids_list
=
sample
[
"text_ids_list"
]
sequence_plan
=
sample
[
"sequence_plan"
]
split_lens
,
attn_modes
=
list
(),
list
()
curr
=
sequence_status
[
"curr"
]
curr_rope_id
=
0
sample_lens
=
0
for
item
in
sequence_plan
:
split_start
=
item
.
get
(
"split_start"
,
True
)
if
split_start
:
curr_split_len
=
0
if
item
[
"type"
]
==
"text"
:
text_ids
=
text_ids_list
.
pop
(
0
)
if
(
item
[
"enable_cfg"
]
==
1
and
random
.
random
()
<
self
.
data_config
.
text_cond_dropout_prob
):
continue
shifted_text_ids
=
[
self
.
bos_token_id
]
+
text_ids
sequence_status
[
"packed_text_ids"
].
extend
(
shifted_text_ids
)
sequence_status
[
"packed_text_indexes"
].
extend
(
range
(
curr
,
curr
+
len
(
shifted_text_ids
))
)
if
item
[
"loss"
]
==
1
:
sequence_status
[
"ce_loss_indexes"
].
extend
(
range
(
curr
,
curr
+
len
(
shifted_text_ids
))
)
sequence_status
[
"ce_loss_weights"
].
extend
(
[
len2weight
(
len
(
shifted_text_ids
))]
*
len
(
shifted_text_ids
)
)
sequence_status
[
"packed_label_ids"
].
extend
(
text_ids
+
[
self
.
eos_token_id
]
)
curr
+=
len
(
shifted_text_ids
)
curr_split_len
+=
len
(
shifted_text_ids
)
# add a <|im_end|> token
sequence_status
[
"packed_text_ids"
].
append
(
self
.
eos_token_id
)
sequence_status
[
"packed_text_indexes"
].
append
(
curr
)
if
item
[
"special_token_loss"
]
==
1
:
# <|im_end|> may have loss
sequence_status
[
"ce_loss_indexes"
].
append
(
curr
)
sequence_status
[
"ce_loss_weights"
].
append
(
1.0
)
sequence_status
[
"packed_label_ids"
].
append
(
item
[
"special_token_label"
]
)
curr
+=
1
curr_split_len
+=
1
# update sequence status
attn_modes
.
append
(
"causal"
)
sequence_status
[
"packed_position_ids"
].
extend
(
range
(
curr_rope_id
,
curr_rope_id
+
curr_split_len
)
)
curr_rope_id
+=
curr_split_len
elif
item
[
"type"
]
==
"vit_image"
:
image_tensor
=
image_tensor_list
.
pop
(
0
)
if
(
item
[
"enable_cfg"
]
==
1
and
random
.
random
()
<
self
.
data_config
.
vit_cond_dropout_prob
):
curr_rope_id
+=
1
continue
# add a <|startofimage|> token
sequence_status
[
"packed_text_ids"
].
append
(
self
.
start_of_image
)
sequence_status
[
"packed_text_indexes"
].
append
(
curr
)
curr
+=
1
curr_split_len
+=
1
# preprocess image
vit_tokens
=
patchify
(
image_tensor
,
self
.
data_config
.
vit_patch_size
)
num_img_tokens
=
vit_tokens
.
shape
[
0
]
sequence_status
[
"packed_vit_token_indexes"
].
extend
(
range
(
curr
,
curr
+
num_img_tokens
)
)
curr
+=
num_img_tokens
curr_split_len
+=
num_img_tokens
sequence_status
[
"packed_vit_tokens"
].
append
(
vit_tokens
)
sequence_status
[
"vit_token_seqlens"
].
append
(
num_img_tokens
)
sequence_status
[
"packed_vit_position_ids"
].
append
(
self
.
get_flattened_position_ids
(
image_tensor
.
size
(
1
),
image_tensor
.
size
(
2
),
self
.
data_config
.
vit_patch_size
,
max_num_patches_per_side
=
self
.
data_config
.
max_num_patch_per_side
,
)
)
# add a <|endofimage|> token
sequence_status
[
"packed_text_ids"
].
append
(
self
.
end_of_image
)
sequence_status
[
"packed_text_indexes"
].
append
(
curr
)
if
item
[
"special_token_loss"
]
==
1
:
# <|endofimage|> may have loss
sequence_status
[
"ce_loss_indexes"
].
append
(
curr
)
sequence_status
[
"ce_loss_weights"
].
append
(
1.0
)
sequence_status
[
"packed_label_ids"
].
append
(
item
[
"special_token_label"
]
)
curr
+=
1
curr_split_len
+=
1
# update sequence status
attn_modes
.
append
(
"full"
)
sequence_status
[
"packed_position_ids"
].
extend
(
[
curr_rope_id
]
*
curr_split_len
)
curr_rope_id
+=
1
elif
item
[
"type"
]
==
"vae_image"
:
image_tensor
=
image_tensor_list
.
pop
(
0
)
if
(
item
[
"enable_cfg"
]
==
1
and
random
.
random
()
<
self
.
data_config
.
vae_cond_dropout_prob
):
# FIXME fix vae dropout in video2video setting.
curr_rope_id
+=
1
continue
# add a <|startofimage|> token
sequence_status
[
"packed_text_ids"
].
append
(
self
.
start_of_image
)
sequence_status
[
"packed_text_indexes"
].
append
(
curr
)
curr
+=
1
curr_split_len
+=
1
# preprocess image
sequence_status
[
"vae_image_tensors"
].
append
(
image_tensor
)
sequence_status
[
"packed_latent_position_ids"
].
append
(
self
.
get_flattened_position_ids
(
image_tensor
.
size
(
1
),
image_tensor
.
size
(
2
),
self
.
data_config
.
vae_image_downsample
,
max_num_patches_per_side
=
self
.
data_config
.
max_latent_size
,
)
)
H
,
W
=
image_tensor
.
shape
[
1
:]
h
=
H
//
self
.
data_config
.
vae_image_downsample
w
=
W
//
self
.
data_config
.
vae_image_downsample
sequence_status
[
"vae_latent_shapes"
].
append
((
h
,
w
))
num_img_tokens
=
w
*
h
sequence_status
[
"packed_vae_token_indexes"
].
extend
(
range
(
curr
,
curr
+
num_img_tokens
)
)
if
item
[
"loss"
]
==
1
:
sequence_status
[
"mse_loss_indexes"
].
extend
(
range
(
curr
,
curr
+
num_img_tokens
)
)
if
split_start
:
timestep
=
np
.
random
.
randn
()
else
:
timestep
=
float
(
"-inf"
)
sequence_status
[
"packed_timesteps"
].
extend
([
timestep
]
*
num_img_tokens
)
curr
+=
num_img_tokens
curr_split_len
+=
num_img_tokens
# add a <|endofimage|> token
sequence_status
[
"packed_text_ids"
].
append
(
self
.
end_of_image
)
sequence_status
[
"packed_text_indexes"
].
append
(
curr
)
# <|endofimage|> may have loss
if
item
[
"special_token_loss"
]
==
1
:
sequence_status
[
"ce_loss_indexes"
].
append
(
curr
)
sequence_status
[
"ce_loss_weights"
].
append
(
1.0
)
sequence_status
[
"packed_label_ids"
].
append
(
item
[
"special_token_label"
]
)
curr
+=
1
curr_split_len
+=
1
# update sequence status
if
split_start
:
if
item
[
"loss"
]
==
1
and
"frame_delta"
not
in
item
.
keys
():
attn_modes
.
append
(
"noise"
)
else
:
attn_modes
.
append
(
"full"
)
sequence_status
[
"packed_position_ids"
].
extend
(
[
curr_rope_id
]
*
(
num_img_tokens
+
2
)
)
if
"frame_delta"
in
item
.
keys
():
curr_rope_id
+=
item
[
"frame_delta"
]
elif
item
[
"loss"
]
==
0
:
curr_rope_id
+=
1
if
item
.
get
(
"split_end"
,
True
):
split_lens
.
append
(
curr_split_len
)
sample_lens
+=
curr_split_len
sequence_status
[
"curr"
]
=
curr
sequence_status
[
"sample_lens"
].
append
(
sample_lens
)
# prepare attention mask
if
not
self
.
use_flex
:
sequence_status
[
"nested_attention_masks"
].
append
(
prepare_attention_mask_per_sample
(
split_lens
,
attn_modes
)
)
else
:
sequence_status
[
"split_lens"
].
extend
(
split_lens
)
sequence_status
[
"attn_modes"
].
extend
(
attn_modes
)
return
sequence_status
class
SimpleCustomBatch
:
def
__init__
(
self
,
batch
):
data
=
batch
[
0
]
self
.
batch_data_indexes
=
data
[
"batch_data_indexes"
]
self
.
sequence_length
=
data
[
"sequence_length"
]
self
.
sample_lens
=
data
[
"sample_lens"
]
self
.
packed_text_ids
=
data
[
"packed_text_ids"
]
self
.
packed_text_indexes
=
data
[
"packed_text_indexes"
]
self
.
packed_position_ids
=
data
[
"packed_position_ids"
]
self
.
use_flex
=
"nested_attention_masks"
not
in
data
.
keys
()
if
self
.
use_flex
:
self
.
split_lens
=
data
[
"split_lens"
]
self
.
attn_modes
=
data
[
"attn_modes"
]
else
:
self
.
nested_attention_masks
=
data
[
"nested_attention_masks"
]
if
"padded_images"
in
data
.
keys
():
self
.
padded_images
=
data
[
"padded_images"
]
self
.
patchified_vae_latent_shapes
=
data
[
"patchified_vae_latent_shapes"
]
self
.
packed_latent_position_ids
=
data
[
"packed_latent_position_ids"
]
self
.
packed_vae_token_indexes
=
data
[
"packed_vae_token_indexes"
]
if
"packed_vit_tokens"
in
data
.
keys
():
self
.
packed_vit_tokens
=
data
[
"packed_vit_tokens"
]
self
.
packed_vit_position_ids
=
data
[
"packed_vit_position_ids"
]
self
.
packed_vit_token_indexes
=
data
[
"packed_vit_token_indexes"
]
self
.
vit_token_seqlens
=
data
[
"vit_token_seqlens"
]
if
"packed_timesteps"
in
data
.
keys
():
self
.
packed_timesteps
=
data
[
"packed_timesteps"
]
self
.
mse_loss_indexes
=
data
[
"mse_loss_indexes"
]
if
"packed_label_ids"
in
data
.
keys
():
self
.
packed_label_ids
=
data
[
"packed_label_ids"
]
self
.
ce_loss_indexes
=
data
[
"ce_loss_indexes"
]
self
.
ce_loss_weights
=
data
[
"ce_loss_weights"
]
def
pin_memory
(
self
):
self
.
packed_text_ids
=
self
.
packed_text_ids
.
pin_memory
()
self
.
packed_text_indexes
=
self
.
packed_text_indexes
.
pin_memory
()
self
.
packed_position_ids
=
self
.
packed_position_ids
.
pin_memory
()
if
not
self
.
use_flex
:
self
.
nested_attention_masks
=
[
item
.
pin_memory
()
for
item
in
self
.
nested_attention_masks
]
if
hasattr
(
self
,
"padded_images"
):
self
.
padded_images
=
self
.
padded_images
.
pin_memory
()
self
.
packed_vae_token_indexes
=
self
.
packed_vae_token_indexes
.
pin_memory
()
self
.
packed_latent_position_ids
=
(
self
.
packed_latent_position_ids
.
pin_memory
()
)
if
hasattr
(
self
,
"packed_timesteps"
):
self
.
packed_timesteps
=
self
.
packed_timesteps
.
pin_memory
()
self
.
mse_loss_indexes
=
self
.
mse_loss_indexes
.
pin_memory
()
if
hasattr
(
self
,
"packed_vit_tokens"
):
self
.
packed_vit_tokens
=
self
.
packed_vit_tokens
.
pin_memory
()
self
.
packed_vit_position_ids
=
self
.
packed_vit_position_ids
.
pin_memory
()
self
.
packed_vit_token_indexes
=
self
.
packed_vit_token_indexes
.
pin_memory
()
self
.
vit_token_seqlens
=
self
.
vit_token_seqlens
.
pin_memory
()
if
hasattr
(
self
,
"packed_label_ids"
):
self
.
packed_label_ids
=
self
.
packed_label_ids
.
pin_memory
()
self
.
ce_loss_indexes
=
self
.
ce_loss_indexes
.
pin_memory
()
self
.
ce_loss_weights
=
self
.
ce_loss_weights
.
pin_memory
()
return
self
def
cuda
(
self
,
device
):
self
.
packed_text_ids
=
self
.
packed_text_ids
.
to
(
device
)
self
.
packed_text_indexes
=
self
.
packed_text_indexes
.
to
(
device
)
self
.
packed_position_ids
=
self
.
packed_position_ids
.
to
(
device
)
if
not
self
.
use_flex
:
self
.
nested_attention_masks
=
[
item
.
to
(
device
)
for
item
in
self
.
nested_attention_masks
]
if
hasattr
(
self
,
"padded_images"
):
self
.
padded_images
=
self
.
padded_images
.
to
(
device
)
self
.
packed_vae_token_indexes
=
self
.
packed_vae_token_indexes
.
to
(
device
)
self
.
packed_latent_position_ids
=
self
.
packed_latent_position_ids
.
to
(
device
)
if
hasattr
(
self
,
"packed_timesteps"
):
self
.
packed_timesteps
=
self
.
packed_timesteps
.
to
(
device
)
self
.
mse_loss_indexes
=
self
.
mse_loss_indexes
.
to
(
device
)
if
hasattr
(
self
,
"packed_vit_tokens"
):
self
.
packed_vit_tokens
=
self
.
packed_vit_tokens
.
to
(
device
)
self
.
packed_vit_position_ids
=
self
.
packed_vit_position_ids
.
to
(
device
)
self
.
packed_vit_token_indexes
=
self
.
packed_vit_token_indexes
.
to
(
device
)
self
.
vit_token_seqlens
=
self
.
vit_token_seqlens
.
to
(
device
)
if
hasattr
(
self
,
"packed_label_ids"
):
self
.
packed_label_ids
=
self
.
packed_label_ids
.
to
(
device
)
self
.
ce_loss_indexes
=
self
.
ce_loss_indexes
.
to
(
device
)
self
.
ce_loss_weights
=
self
.
ce_loss_weights
.
to
(
device
)
return
self
def
to_dict
(
self
):
data
=
dict
(
sequence_length
=
self
.
sequence_length
,
sample_lens
=
self
.
sample_lens
,
packed_text_ids
=
self
.
packed_text_ids
,
packed_text_indexes
=
self
.
packed_text_indexes
,
packed_position_ids
=
self
.
packed_position_ids
,
batch_data_indexes
=
self
.
batch_data_indexes
,
)
if
not
self
.
use_flex
:
data
[
"nested_attention_masks"
]
=
self
.
nested_attention_masks
else
:
data
[
"split_lens"
]
=
self
.
split_lens
data
[
"attn_modes"
]
=
self
.
attn_modes
if
hasattr
(
self
,
"padded_images"
):
data
[
"padded_images"
]
=
self
.
padded_images
data
[
"patchified_vae_latent_shapes"
]
=
self
.
patchified_vae_latent_shapes
data
[
"packed_latent_position_ids"
]
=
self
.
packed_latent_position_ids
data
[
"packed_vae_token_indexes"
]
=
self
.
packed_vae_token_indexes
if
hasattr
(
self
,
"packed_vit_tokens"
):
data
[
"packed_vit_tokens"
]
=
self
.
packed_vit_tokens
data
[
"packed_vit_position_ids"
]
=
self
.
packed_vit_position_ids
data
[
"packed_vit_token_indexes"
]
=
self
.
packed_vit_token_indexes
data
[
"vit_token_seqlens"
]
=
self
.
vit_token_seqlens
if
hasattr
(
self
,
"packed_timesteps"
):
data
[
"packed_timesteps"
]
=
self
.
packed_timesteps
data
[
"mse_loss_indexes"
]
=
self
.
mse_loss_indexes
if
hasattr
(
self
,
"packed_label_ids"
):
data
[
"packed_label_ids"
]
=
self
.
packed_label_ids
data
[
"ce_loss_indexes"
]
=
self
.
ce_loss_indexes
data
[
"ce_loss_weights"
]
=
self
.
ce_loss_weights
return
data
def
collate_wrapper
():
def
collate_fn
(
batch
):
return
SimpleCustomBatch
(
batch
)
return
collate_fn
SenseNova-SI-main/training/bagel/data/dataset_info.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
glob
import
json
import
os
import
os.path
as
osp
from
.edit_dataset_jsonl
import
EditJSONLIterableDataset
from
.interleave_datasets
import
UnifiedEditIterableDataset
from
.t2i_dataset
import
T2IIterableDataset
from
.t2i_dataset_jsonl
import
T2IJSONLIterableDataset
from
.vlm_dataset
import
SftJSONLIterableDataset
DATASET_REGISTRY
=
{
"sensenova_si_800K"
:
SftJSONLIterableDataset
,
"sensenova_si_8M"
:
SftJSONLIterableDataset
,
}
DATASET_INFO
=
{}
# load additional dataset info from the dataset_info/ directory
dataset_info_path
=
osp
.
join
(
osp
.
dirname
(
__file__
),
"dataset_info"
)
dataset_info_files
=
glob
.
glob
(
osp
.
join
(
dataset_info_path
,
"*.json"
))
training_root
=
os
.
environ
.
get
(
"TRAINING_ROOT"
,
osp
.
abspath
(
osp
.
join
(
osp
.
dirname
(
__file__
),
".."
,
".."
,
".."
)),
)
def
_resolve_training_root_path
(
value
):
if
isinstance
(
value
,
str
):
return
value
.
replace
(
"__TRAINING_ROOT__"
,
training_root
)
if
isinstance
(
value
,
list
):
return
[
_resolve_training_root_path
(
v
)
for
v
in
value
]
if
isinstance
(
value
,
dict
):
return
{
k
:
_resolve_training_root_path
(
v
)
for
k
,
v
in
value
.
items
()}
return
value
for
dataset_info_file
in
dataset_info_files
:
with
open
(
dataset_info_file
,
"r"
)
as
f
:
base_name
=
osp
.
splitext
(
osp
.
basename
(
dataset_info_file
))[
0
]
dataset_info
=
_resolve_training_root_path
(
json
.
load
(
f
))
for
key
in
dataset_info
.
keys
():
if
key
in
DATASET_INFO
:
raise
ValueError
(
f
"Key
{
key
}
already exists in DATASET_INFO"
)
DATASET_INFO
.
update
({
base_name
:
dataset_info
})
SenseNova-SI-main/training/bagel/data/dataset_info/sensenova_si_800K.json
0 → 100644
View file @
876a36a4
{
"sensenova_si_800K"
:
{
"data_dir"
:
"__TRAINING_ROOT__/data/SenseNova-SI-800K/"
,
"jsonl_path"
:
"__TRAINING_ROOT__/data/SenseNova-SI-800K/SenseNova-SI-800K.jsonl"
,
"num_total_samples"
:
832132
}
}
SenseNova-SI-main/training/bagel/data/dataset_info/sensenova_si_8M.json
0 → 100644
View file @
876a36a4
{
"sensenova_si_8M"
:
{
"data_dir"
:
"__TRAINING_ROOT__/data/SenseNova-SI-8M/"
,
"jsonl_path"
:
"__TRAINING_ROOT__/data/SenseNova-SI-8M/SenseNova-SI-8M.jsonl"
,
"num_total_samples"
:
8165067
}
}
SenseNova-SI-main/training/bagel/data/distributed_iterable_dataset.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
random
import
torch
class
DistributedIterableDataset
(
torch
.
utils
.
data
.
IterableDataset
):
def
__init__
(
self
,
dataset_name
,
local_rank
=
0
,
world_size
=
1
,
num_workers
=
8
):
self
.
dataset_name
=
dataset_name
self
.
local_rank
=
local_rank
self
.
world_size
=
world_size
self
.
num_workers
=
num_workers
self
.
rng
=
random
.
Random
()
self
.
data_paths
=
None
def
get_data_paths
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
def
set_epoch
(
self
,
seed
=
42
):
if
self
.
data_paths
is
None
:
return
if
isinstance
(
self
.
data_paths
[
0
],
tuple
):
data_paths
=
sorted
(
self
.
data_paths
,
key
=
lambda
x
:
(
x
[
0
],
x
[
1
]))
elif
isinstance
(
self
.
data_paths
[
0
],
str
):
data_paths
=
sorted
(
self
.
data_paths
)
else
:
raise
ValueError
(
f
"Unknown data_paths type:
{
type
(
self
.
data_paths
[
0
])
}
"
)
self
.
rng
.
seed
(
seed
)
self
.
rng
.
shuffle
(
data_paths
)
num_files_per_rank
=
len
(
data_paths
)
//
self
.
world_size
local_start
=
self
.
local_rank
*
num_files_per_rank
local_end
=
(
self
.
local_rank
+
1
)
*
num_files_per_rank
self
.
num_files_per_rank
=
num_files_per_rank
self
.
data_paths_per_rank
=
data_paths
[
local_start
:
local_end
]
def
get_data_paths_per_worker
(
self
):
if
self
.
data_paths
is
None
:
return
None
info
=
torch
.
utils
.
data
.
get_worker_info
()
if
info
is
None
:
# Single worker: Use all files assigned to the rank
return
self
.
data_paths_per_rank
,
0
worker_id
=
info
.
id
num_files_per_worker
=
self
.
num_files_per_rank
//
info
.
num_workers
start
=
num_files_per_worker
*
worker_id
end
=
num_files_per_worker
*
(
worker_id
+
1
)
data_paths_per_worker
=
self
.
data_paths_per_rank
[
start
:
end
]
return
data_paths_per_worker
[::
-
1
],
worker_id
def
__iter__
(
self
):
raise
NotImplementedError
SenseNova-SI-main/training/bagel/data/edit_dataset_jsonl.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
io
import
json
import
os
import
random
import
pyarrow.parquet
as
pq
from
PIL
import
Image
,
ImageFile
,
PngImagePlugin
from
.data_utils
import
load_image
,
pil_img2rgb
from
.distributed_iterable_dataset
import
DistributedIterableDataset
from
.parquet_utils
import
get_parquet_data_paths
,
init_arrow_pf_fs
Image
.
MAX_IMAGE_PIXELS
=
200000000
ImageFile
.
LOAD_TRUNCATED_IMAGES
=
True
MaximumDecompressedSize
=
1024
MegaByte
=
2
**
20
PngImagePlugin
.
MAX_TEXT_CHUNK
=
MaximumDecompressedSize
*
MegaByte
class
EditJSONLIterableDataset
(
DistributedIterableDataset
):
def
_add_text
(
self
,
sample
,
text
,
need_loss
,
enable_cfg
=
True
):
text_ids
=
self
.
tokenizer
.
encode
(
text
)
sample
[
"num_tokens"
]
+=
len
(
text_ids
)
sample
[
"text_ids_list"
].
append
(
text_ids
)
sample
[
"sequence_plan"
].
append
(
{
"type"
:
"text"
,
"enable_cfg"
:
int
(
enable_cfg
),
"loss"
:
int
(
need_loss
),
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
)
return
sample
def
_resize_and_pad
(
self
,
img
:
Image
.
Image
,
is_mask
=
False
)
->
Image
.
Image
:
"""根据 __init__ 里解析好的 fixed_size 进行 resize/pad"""
if
self
.
fixed_size
==
None
:
return
img
interp
=
Image
.
NEAREST
if
is_mask
else
Image
.
BICUBIC
# case1: (H,W) 矩形 resize
# if self.fixed_mode == "rect":
target_h
,
target_w
=
self
.
fixed_size
,
self
.
fixed_size
return
img
.
resize
((
target_w
,
target_h
),
interp
)
def
_add_image
(
self
,
sample
,
image
,
need_loss
,
need_vae
,
need_vit
,
enable_cfg
=
True
):
assert
need_loss
or
need_vae
or
need_vit
if
need_loss
:
sample
[
"sequence_plan"
].
append
(
{
"type"
:
"vae_image"
,
"enable_cfg"
:
0
,
"loss"
:
1
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
)
image_tensor
=
self
.
transform
(
image
)
height
,
width
=
image_tensor
.
shape
[
1
:]
sample
[
"num_tokens"
]
+=
width
*
height
//
self
.
transform
.
stride
**
2
sample
[
"image_tensor_list"
].
append
(
image_tensor
)
if
need_vae
:
sample
[
"sequence_plan"
].
append
(
{
"type"
:
"vae_image"
,
"enable_cfg"
:
int
(
enable_cfg
),
"loss"
:
0
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
)
image_tensor
=
self
.
transform
(
image
)
height
,
width
=
image_tensor
.
shape
[
1
:]
sample
[
"num_tokens"
]
+=
width
*
height
//
self
.
transform
.
stride
**
2
sample
[
"image_tensor_list"
].
append
(
image_tensor
.
clone
())
if
need_vit
:
sample
[
"sequence_plan"
].
append
(
{
"type"
:
"vit_image"
,
"enable_cfg"
:
int
(
enable_cfg
),
"loss"
:
0
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
},
)
vit_image_tensor
=
self
.
vit_transform
(
image
)
height
,
width
=
vit_image_tensor
.
shape
[
1
:]
sample
[
"num_tokens"
]
+=
width
*
height
//
self
.
vit_transform
.
stride
**
2
sample
[
"image_tensor_list"
].
append
(
vit_image_tensor
)
return
sample
def
__init__
(
self
,
dataset_name
,
transform
,
tokenizer
,
vit_transform
,
jsonl_path_list
,
data_dir_list
,
num_used_data
,
local_rank
=
0
,
world_size
=
1
,
num_workers
=
8
,
data_status
=
None
,
shuffle_lines
=
False
,
shuffle_seed
=
0
,
fixed_size
=
None
,
):
"""
jsonl_path_list: list of jsonl file paths
data_dir_list: list of image directories containing the images of each jsonl file
num_used_data: list of number of sampled data points for each jsonl
"""
super
().
__init__
(
dataset_name
,
local_rank
,
world_size
,
num_workers
)
self
.
transform
=
transform
if
fixed_size
is
None
:
self
.
fixed_size
=
None
else
:
self
.
fixed_size
=
fixed_size
self
.
tokenizer
=
tokenizer
self
.
vit_transform
=
vit_transform
self
.
data_status
=
data_status
self
.
data_paths
=
self
.
get_data_paths
(
jsonl_path_list
,
data_dir_list
,
num_used_data
,
shuffle_lines
,
shuffle_seed
,
)
self
.
set_epoch
()
def
get_data_paths
(
self
,
jsonl_path_list
,
data_dir_list
,
num_used_data
,
shuffle_lines
,
shuffle_seed
,
):
data_paths
=
[]
for
jsonl_path
,
image_dir
,
num_data_point
in
zip
(
jsonl_path_list
,
data_dir_list
,
num_used_data
):
with
open
(
jsonl_path
,
"r"
)
as
f
:
raw_data
=
f
.
readlines
()
if
shuffle_lines
:
self
.
rng
.
seed
(
shuffle_seed
)
self
.
rng
.
shuffle
(
raw_data
)
raw_data
=
raw_data
[:
num_data_point
]
data_paths
.
extend
([(
json_data
,
image_dir
)
for
json_data
in
raw_data
])
return
data_paths
def
__iter__
(
self
):
data_paths_per_worker
,
worker_id
=
self
.
get_data_paths_per_worker
()
if
self
.
data_status
is
not
None
:
row_start_id
=
self
.
data_status
[
worker_id
]
+
1
else
:
row_start_id
=
0
transform_stride
=
self
.
transform
.
stride
print
(
f
"rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
dataset-
{
self
.
dataset_name
}
: "
f
"resuming data at row#
{
row_start_id
}
"
)
while
True
:
data_paths_per_worker_
=
data_paths_per_worker
[
row_start_id
:]
for
row_idx
,
(
data
,
image_dir
)
in
enumerate
(
data_paths_per_worker_
,
start
=
row_start_id
):
sample
=
{
"sequence_plan"
:
[],
"text_ids_list"
:
[],
"image_tensor_list"
:
[],
"num_tokens"
:
0
,
}
# try:
data_item
=
json
.
loads
(
data
)
sample
=
self
.
_add_image
(
sample
,
# pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'][0]))),
pil_img2rgb
(
self
.
_resize_and_pad
(
load_image
(
os
.
path
.
join
(
image_dir
,
data_item
[
"image"
][
0
]))
)
),
need_loss
=
False
,
need_vae
=
True
,
need_vit
=
True
,
)
if
"instruction"
in
data_item
:
instruction
=
data_item
[
"instruction"
]
elif
"conversations"
in
data_item
:
conversations
=
data_item
[
"conversations"
]
if
len
(
conversations
)
==
2
:
if
conversations
[
0
][
"from"
]
==
"human"
:
instruction
=
conversations
[
0
][
"value"
].
replace
(
"<image>"
,
""
)
# instruction = data_item['conversation']
else
:
print
(
"no caption in "
+
data_item
)
sample
=
self
.
_add_text
(
sample
,
instruction
.
rstrip
(),
need_loss
=
False
)
sample
=
self
.
_add_image
(
sample
,
# pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'][1]))),
pil_img2rgb
(
self
.
_resize_and_pad
(
load_image
(
os
.
path
.
join
(
image_dir
,
data_item
[
"image"
][
1
]))
)
),
need_loss
=
True
,
need_vae
=
False
,
need_vit
=
False
,
)
# except:
# print(f"Error in row {row_idx}")
# continue
sample
[
"data_indexes"
]
=
{
"data_indexes"
:
row_idx
,
"worker_id"
:
worker_id
,
"dataset_name"
:
self
.
dataset_name
,
}
# print('image[0]: ',sample['image_tensor_list'][0].shape)
# print('image[1]: ',sample['image_tensor_list'][1].shape)
yield
sample
row_start_id
=
0
print
(
f
"
{
self
.
dataset_name
}
repeat in rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
"
)
SenseNova-SI-main/training/bagel/data/parquet_utils.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
logging
import
os
import
subprocess
import
pyarrow.fs
as
pf
import
torch.distributed
as
dist
logger
=
logging
.
getLogger
(
__name__
)
def
get_parquet_data_paths
(
data_dir_list
,
num_sampled_data_paths
,
rank
=
0
,
world_size
=
1
):
num_data_dirs
=
len
(
data_dir_list
)
if
world_size
>
1
:
chunk_size
=
(
num_data_dirs
+
world_size
-
1
)
//
world_size
start_idx
=
rank
*
chunk_size
end_idx
=
min
(
start_idx
+
chunk_size
,
num_data_dirs
)
local_data_dir_list
=
data_dir_list
[
start_idx
:
end_idx
]
local_num_sampled_data_paths
=
num_sampled_data_paths
[
start_idx
:
end_idx
]
else
:
local_data_dir_list
=
data_dir_list
local_num_sampled_data_paths
=
num_sampled_data_paths
local_data_paths
=
[]
for
data_dir
,
num_data_path
in
zip
(
local_data_dir_list
,
local_num_sampled_data_paths
):
if
data_dir
.
startswith
(
"hdfs://"
):
files
=
hdfs_ls_cmd
(
data_dir
)
data_paths_per_dir
=
[
file
for
file
in
files
if
file
.
endswith
(
".parquet"
)]
else
:
files
=
os
.
listdir
(
data_dir
)
data_paths_per_dir
=
[
os
.
path
.
join
(
data_dir
,
name
)
for
name
in
files
if
name
.
endswith
(
".parquet"
)
]
repeat
=
num_data_path
//
len
(
data_paths_per_dir
)
data_paths_per_dir
=
data_paths_per_dir
*
(
repeat
+
1
)
local_data_paths
.
extend
(
data_paths_per_dir
[:
num_data_path
])
if
world_size
>
1
:
gather_list
=
[
None
]
*
world_size
dist
.
all_gather_object
(
gather_list
,
local_data_paths
)
combined_chunks
=
[]
for
chunk_list
in
gather_list
:
if
chunk_list
is
not
None
:
combined_chunks
.
extend
(
chunk_list
)
else
:
combined_chunks
=
local_data_paths
return
combined_chunks
# NOTE: cumtomize this function for your cluster
def
get_hdfs_host
():
return
"hdfs://xxx"
# NOTE: cumtomize this function for your cluster
def
get_hdfs_block_size
():
return
134217728
# NOTE: cumtomize this function for your cluster
def
get_hdfs_extra_conf
():
return
None
def
init_arrow_pf_fs
(
parquet_file_path
):
if
parquet_file_path
.
startswith
(
"hdfs://"
):
fs
=
pf
.
HadoopFileSystem
(
host
=
get_hdfs_host
(),
port
=
0
,
buffer_size
=
get_hdfs_block_size
(),
extra_conf
=
get_hdfs_extra_conf
(),
)
else
:
fs
=
pf
.
LocalFileSystem
()
return
fs
def
hdfs_ls_cmd
(
dir
):
result
=
subprocess
.
run
(
[
"hdfs"
,
"dfs"
,
"ls"
,
dir
],
capture_output
=
True
,
text
=
True
).
stdout
return
[
"hdfs://"
+
i
.
split
(
"hdfs://"
)[
-
1
].
strip
()
for
i
in
result
.
split
(
"
\n
"
)
if
"hdfs://"
in
i
]
SenseNova-SI-main/training/bagel/data/t2i_dataset.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
io
import
json
import
random
import
pyarrow.parquet
as
pq
from
PIL
import
Image
from
.data_utils
import
pil_img2rgb
from
.distributed_iterable_dataset
import
DistributedIterableDataset
from
.parquet_utils
import
get_parquet_data_paths
,
init_arrow_pf_fs
Image
.
MAX_IMAGE_PIXELS
=
20_000_000
class
T2IIterableDataset
(
DistributedIterableDataset
):
def
__init__
(
self
,
dataset_name
,
transform
,
tokenizer
,
data_dir_list
,
num_used_data
,
local_rank
=
0
,
world_size
=
1
,
num_workers
=
8
,
data_status
=
None
,
):
"""
data_dir_list: list of data directories contains parquet files
num_used_data: list of number of sampled data paths for each data directory
"""
super
().
__init__
(
dataset_name
,
local_rank
,
world_size
,
num_workers
)
self
.
transform
=
transform
self
.
tokenizer
=
tokenizer
self
.
data_status
=
data_status
self
.
data_paths
=
self
.
get_data_paths
(
data_dir_list
,
num_used_data
)
self
.
set_epoch
()
def
get_data_paths
(
self
,
data_dir_list
,
num_used_data
):
return
get_parquet_data_paths
(
data_dir_list
,
num_used_data
)
def
__iter__
(
self
):
data_paths_per_worker
,
worker_id
=
self
.
get_data_paths_per_worker
()
if
self
.
data_status
is
not
None
:
parquet_start_id
=
self
.
data_status
[
worker_id
][
0
]
row_group_start_id
=
self
.
data_status
[
worker_id
][
1
]
row_start_id
=
self
.
data_status
[
worker_id
][
2
]
+
1
else
:
parquet_start_id
=
0
row_group_start_id
=
0
row_start_id
=
0
transform_stride
=
self
.
transform
.
stride
print
(
f
"rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
dataset-
{
self
.
dataset_name
}
: "
f
"resuming data at parquet#
{
parquet_start_id
}
, rg#
{
row_group_start_id
}
, row#
{
row_start_id
}
"
)
while
True
:
data_paths_per_worker_
=
data_paths_per_worker
[
parquet_start_id
:]
for
parquet_idx
,
parquet_file_path
in
enumerate
(
data_paths_per_worker_
,
start
=
parquet_start_id
):
fs
=
init_arrow_pf_fs
(
parquet_file_path
)
with
fs
.
open_input_file
(
parquet_file_path
)
as
f
:
fr
=
pq
.
ParquetFile
(
f
)
row_group_ids
=
list
(
range
(
fr
.
num_row_groups
))
row_group_ids_
=
row_group_ids
[
row_group_start_id
:]
for
row_group_id
in
row_group_ids_
:
df
=
fr
.
read_row_group
(
row_group_id
).
to_pandas
()
df
=
df
.
iloc
[
row_start_id
:]
for
row_idx
,
row
in
df
.
iterrows
():
num_tokens
=
0
try
:
image_byte
=
row
[
"image"
]
image
=
pil_img2rgb
(
Image
.
open
(
io
.
BytesIO
(
image_byte
)))
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
in rg#
{
row_group_id
}
,
{
parquet_file_path
}
"
)
continue
image_tensor
=
self
.
transform
(
image
)
height
,
width
=
image_tensor
.
shape
[
1
:]
num_tokens
+=
width
*
height
//
transform_stride
**
2
try
:
caption_dict
=
row
[
"captions"
]
caption_dict
=
json
.
loads
(
caption_dict
)
except
Exception
as
e
:
print
(
f
"Error:
{
e
}
in rg#
{
row_group_id
}
,
{
parquet_file_path
}
"
)
continue
caps_token
=
[
self
.
tokenizer
.
encode
(
v
)
for
_
,
v
in
caption_dict
.
items
()
]
if
len
(
caps_token
)
==
0
:
print
(
f
"no caption in rg#
{
row_group_id
}
,
{
parquet_file_path
}
"
)
caption_token
=
self
.
tokenizer
.
encode
(
" "
)
else
:
caption_token
=
random
.
choice
(
caps_token
)
sequence_plan
,
text_ids_list
=
[],
[]
text_ids
=
caption_token
num_tokens
+=
len
(
caption_token
)
text_ids_list
.
append
(
text_ids
)
sequence_plan
.
append
(
{
"type"
:
"text"
,
"enable_cfg"
:
1
,
"loss"
:
0
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
)
sequence_plan
.
append
(
{
"type"
:
"vae_image"
,
"enable_cfg"
:
0
,
"loss"
:
1
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
)
sample
=
dict
(
image_tensor_list
=
[
image_tensor
],
text_ids_list
=
text_ids_list
,
num_tokens
=
num_tokens
,
sequence_plan
=
sequence_plan
,
data_indexes
=
{
"data_indexes"
:
[
parquet_idx
,
row_group_id
,
row_idx
,
],
"worker_id"
:
worker_id
,
"dataset_name"
:
self
.
dataset_name
,
},
)
yield
sample
row_start_id
=
0
row_group_start_id
=
0
parquet_start_id
=
0
print
(
f
"
{
self
.
dataset_name
}
repeat in rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
"
)
SenseNova-SI-main/training/bagel/data/t2i_dataset_jsonl.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
io
import
json
import
os
import
random
import
traceback
import
pyarrow.parquet
as
pq
from
PIL
import
Image
from
.data_utils
import
load_image
,
pil_img2rgb
from
.distributed_iterable_dataset
import
DistributedIterableDataset
from
.parquet_utils
import
get_parquet_data_paths
,
init_arrow_pf_fs
Image
.
MAX_IMAGE_PIXELS
=
200_000_000
class
T2IJSONLIterableDataset
(
DistributedIterableDataset
):
def
__init__
(
self
,
dataset_name
,
transform
,
tokenizer
,
jsonl_path_list
,
data_dir_list
,
num_used_data
,
local_rank
=
0
,
world_size
=
1
,
num_workers
=
8
,
data_status
=
None
,
):
"""
data_dir_list: list of data directories contains parquet files
num_used_data: list of number of sampled data paths for each data directory
"""
super
().
__init__
(
dataset_name
,
local_rank
,
world_size
,
num_workers
)
self
.
transform
=
transform
self
.
tokenizer
=
tokenizer
self
.
data_status
=
data_status
self
.
data_paths
=
self
.
get_data_paths
(
jsonl_path_list
,
data_dir_list
,
num_used_data
)
self
.
set_epoch
()
def
get_data_paths
(
self
,
jsonl_path_list
,
data_dir_list
,
num_used_data
):
data_paths
=
[]
for
jsonl_path
,
image_dir
,
num_data_point
in
zip
(
jsonl_path_list
,
data_dir_list
,
num_used_data
):
with
open
(
jsonl_path
,
"r"
)
as
f
:
raw_data
=
f
.
readlines
()
raw_data
=
raw_data
[:
num_data_point
]
data_paths
.
extend
([(
json_data
,
image_dir
)
for
json_data
in
raw_data
])
return
data_paths
def
__iter__
(
self
):
data_paths_per_worker
,
worker_id
=
self
.
get_data_paths_per_worker
()
if
self
.
data_status
is
not
None
:
row_start_id
=
self
.
data_status
[
worker_id
]
+
1
else
:
row_start_id
=
0
transform_stride
=
self
.
transform
.
stride
print
(
f
"rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
dataset-
{
self
.
dataset_name
}
: "
f
"resuming data at row#
{
row_start_id
}
"
)
while
True
:
data_paths_per_worker_
=
data_paths_per_worker
[
row_start_id
:]
for
row_idx
,
(
data
,
image_dir
)
in
enumerate
(
data_paths_per_worker_
,
start
=
row_start_id
):
num_tokens
=
0
try
:
data_item
=
json
.
loads
(
data
)
image
=
None
if
"image"
in
data_item
:
image
=
pil_img2rgb
(
load_image
(
os
.
path
.
join
(
image_dir
,
data_item
[
"image"
]))
)
except
Exception
as
e
:
# print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
print
(
f
"Erroe image:
{
e
}
in
{
data
}
in
{
self
.
dataset_name
}
"
)
traceback
.
print_exc
()
continue
image_tensor
=
self
.
transform
(
image
)
height
,
width
=
image_tensor
.
shape
[
1
:]
num_tokens
+=
width
*
height
//
transform_stride
**
2
try
:
if
"conversations"
in
data_item
:
caption_list
=
data_item
[
"conversations"
]
if
caption_list
[
0
][
"from"
]
==
"human"
:
caption_str
=
caption_list
[
0
][
"value"
]
caption_dict
=
{
"captions"
:
caption_str
}
# if 'captions' in row.keys():
# caption_dict = row['captions']
# caption_dict = json.loads(caption_dict)
# elif 'txt' in row.keys():
# caption_str = row['txt']
# caption_dict = {'captions':caption_str}
except
Exception
as
e
:
print
(
f
"Error caption:
{
e
}
in
{
data
}
in
{
self
.
dataset_name
}
"
)
continue
caps_token
=
[
self
.
tokenizer
.
encode
(
v
)
for
_
,
v
in
caption_dict
.
items
()]
if
len
(
caps_token
)
==
0
:
print
(
f
"no caption in
{
data
}
in
{
self
.
dataset_name
}
"
)
caption_token
=
self
.
tokenizer
.
encode
(
" "
)
else
:
caption_token
=
random
.
choice
(
caps_token
)
sequence_plan
,
text_ids_list
=
[],
[]
text_ids
=
caption_token
num_tokens
+=
len
(
caption_token
)
text_ids_list
.
append
(
text_ids
)
sequence_plan
.
append
(
{
"type"
:
"text"
,
"enable_cfg"
:
1
,
"loss"
:
0
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
)
sequence_plan
.
append
(
{
"type"
:
"vae_image"
,
"enable_cfg"
:
0
,
"loss"
:
1
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
)
sample
=
dict
(
image_tensor_list
=
[
image_tensor
],
text_ids_list
=
text_ids_list
,
num_tokens
=
num_tokens
,
sequence_plan
=
sequence_plan
,
data_indexes
=
{
"data_indexes"
:
row_idx
,
"worker_id"
:
worker_id
,
"dataset_name"
:
self
.
dataset_name
,
},
)
yield
sample
row_start_id
=
0
print
(
f
"
{
self
.
dataset_name
}
repeat in rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
"
)
SenseNova-SI-main/training/bagel/data/transforms.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
random
import
cv2
import
numpy
as
np
import
torch
from
PIL
import
Image
from
torchvision
import
transforms
from
torchvision.transforms
import
InterpolationMode
from
torchvision.transforms
import
functional
as
F
class
MaxLongEdgeMinShortEdgeResize
(
torch
.
nn
.
Module
):
"""Resize the input image so that its longest side and shortest side are within a specified range,
ensuring that both sides are divisible by a specified stride.
Args:
max_size (int): Maximum size for the longest edge of the image.
min_size (int): Minimum size for the shortest edge of the image.
stride (int): Value by which the height and width of the image must be divisible.
max_pixels (int): Maximum pixels for the full image.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
antialias (bool, optional): Whether to apply antialiasing (default is True).
"""
def
__init__
(
self
,
max_size
:
int
,
min_size
:
int
,
stride
:
int
,
max_pixels
:
int
,
interpolation
=
InterpolationMode
.
BICUBIC
,
antialias
=
True
,
):
super
().
__init__
()
self
.
max_size
=
max_size
self
.
min_size
=
min_size
self
.
stride
=
stride
self
.
max_pixels
=
max_pixels
self
.
interpolation
=
interpolation
self
.
antialias
=
antialias
def
_make_divisible
(
self
,
value
,
stride
):
"""Ensure the value is divisible by the stride."""
return
max
(
stride
,
int
(
round
(
value
/
stride
)
*
stride
))
def
_apply_scale
(
self
,
width
,
height
,
scale
):
new_width
=
round
(
width
*
scale
)
new_height
=
round
(
height
*
scale
)
new_width
=
self
.
_make_divisible
(
new_width
,
self
.
stride
)
new_height
=
self
.
_make_divisible
(
new_height
,
self
.
stride
)
return
new_width
,
new_height
def
forward
(
self
,
img
,
img_num
=
1
):
"""
Args:
img (PIL Image): Image to be resized.
img_num (int): Number of images, used to change max_tokens.
Returns:
PIL Image or Tensor: Rescaled image with divisible dimensions.
"""
if
isinstance
(
img
,
torch
.
Tensor
):
height
,
width
=
img
.
shape
[
-
2
:]
else
:
width
,
height
=
img
.
size
scale
=
min
(
self
.
max_size
/
max
(
width
,
height
),
1.0
)
scale
=
max
(
scale
,
self
.
min_size
/
min
(
width
,
height
))
new_width
,
new_height
=
self
.
_apply_scale
(
width
,
height
,
scale
)
# Ensure the number of pixels does not exceed max_pixels
if
new_width
*
new_height
>
self
.
max_pixels
/
img_num
:
scale
=
self
.
max_pixels
/
img_num
/
(
new_width
*
new_height
)
new_width
,
new_height
=
self
.
_apply_scale
(
new_width
,
new_height
,
scale
)
# Ensure longest edge does not exceed max_size
if
max
(
new_width
,
new_height
)
>
self
.
max_size
:
scale
=
self
.
max_size
/
max
(
new_width
,
new_height
)
new_width
,
new_height
=
self
.
_apply_scale
(
new_width
,
new_height
,
scale
)
return
F
.
resize
(
img
,
(
new_height
,
new_width
),
self
.
interpolation
,
antialias
=
self
.
antialias
)
class
ImageTransform
:
def
__init__
(
self
,
max_image_size
,
min_image_size
,
image_stride
,
max_pixels
=
14
*
14
*
9
*
1024
,
image_mean
=
[
0.5
,
0.5
,
0.5
],
image_std
=
[
0.5
,
0.5
,
0.5
],
):
self
.
stride
=
image_stride
self
.
resize_transform
=
MaxLongEdgeMinShortEdgeResize
(
max_size
=
max_image_size
,
min_size
=
min_image_size
,
stride
=
image_stride
,
max_pixels
=
max_pixels
,
)
self
.
to_tensor_transform
=
transforms
.
ToTensor
()
self
.
normalize_transform
=
transforms
.
Normalize
(
mean
=
image_mean
,
std
=
image_std
,
inplace
=
True
)
def
__call__
(
self
,
img
,
img_num
=
1
):
img
=
self
.
resize_transform
(
img
,
img_num
=
img_num
)
img
=
self
.
to_tensor_transform
(
img
)
img
=
self
.
normalize_transform
(
img
)
return
img
def
decolorization
(
image
):
gray_image
=
image
.
convert
(
"L"
)
return
(
Image
.
merge
(
image
.
mode
,
[
gray_image
]
*
3
)
if
image
.
mode
in
(
"RGB"
,
"L"
)
else
gray_image
)
def
downscale
(
image
,
scale_factor
):
new_width
=
int
(
round
(
image
.
width
*
scale_factor
))
new_height
=
int
(
round
(
image
.
height
*
scale_factor
))
new_width
=
max
(
1
,
new_width
)
new_height
=
max
(
1
,
new_height
)
return
image
.
resize
((
new_width
,
new_height
),
resample
=
Image
.
BICUBIC
)
def
crop
(
image
,
crop_factors
):
target_h
,
target_w
=
crop_factors
img_w
,
img_h
=
image
.
size
if
target_h
>
img_h
or
target_w
>
img_w
:
raise
ValueError
(
"Crop size exceeds image dimensions"
)
x
=
random
.
randint
(
0
,
img_w
-
target_w
)
y
=
random
.
randint
(
0
,
img_h
-
target_h
)
return
image
.
crop
((
x
,
y
,
x
+
target_w
,
y
+
target_h
)),
[
[
x
,
y
],
[
x
+
target_w
,
y
+
target_h
],
]
def
motion_blur_opencv
(
image
,
kernel_size
=
15
,
angle
=
0
):
# 线性核
kernel
=
np
.
zeros
((
kernel_size
,
kernel_size
),
dtype
=
np
.
float32
)
kernel
[
kernel_size
//
2
,
:]
=
np
.
ones
(
kernel_size
,
dtype
=
np
.
float32
)
# 旋转核
center
=
(
kernel_size
/
2
-
0.5
,
kernel_size
/
2
-
0.5
)
M
=
cv2
.
getRotationMatrix2D
(
center
,
angle
,
1
)
rotated_kernel
=
cv2
.
warpAffine
(
kernel
,
M
,
(
kernel_size
,
kernel_size
))
# 归一化核
rotated_kernel
/=
rotated_kernel
.
sum
()
if
rotated_kernel
.
sum
()
!=
0
else
1
img
=
np
.
array
(
image
)
if
img
.
ndim
==
2
:
blurred
=
cv2
.
filter2D
(
img
,
-
1
,
rotated_kernel
,
borderType
=
cv2
.
BORDER_REFLECT
)
else
:
# 对于彩色图像,各通道独立卷积
blurred
=
np
.
zeros_like
(
img
)
for
c
in
range
(
img
.
shape
[
2
]):
blurred
[...,
c
]
=
cv2
.
filter2D
(
img
[...,
c
],
-
1
,
rotated_kernel
,
borderType
=
cv2
.
BORDER_REFLECT
)
return
Image
.
fromarray
(
blurred
.
astype
(
np
.
uint8
))
def
shuffle_patch
(
image
,
num_splits
,
gap_size
=
2
):
"""将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
h_splits
,
w_splits
=
num_splits
img_w
,
img_h
=
image
.
size
base_patch_h
=
img_h
//
h_splits
patch_heights
=
[
base_patch_h
]
*
(
h_splits
-
1
)
patch_heights
.
append
(
img_h
-
sum
(
patch_heights
))
base_patch_w
=
img_w
//
w_splits
patch_widths
=
[
base_patch_w
]
*
(
w_splits
-
1
)
patch_widths
.
append
(
img_w
-
sum
(
patch_widths
))
patches
=
[]
current_y
=
0
for
i
in
range
(
h_splits
):
current_x
=
0
patch_h
=
patch_heights
[
i
]
for
j
in
range
(
w_splits
):
patch_w
=
patch_widths
[
j
]
patch
=
image
.
crop
(
(
current_x
,
current_y
,
current_x
+
patch_w
,
current_y
+
patch_h
)
)
patches
.
append
(
patch
)
current_x
+=
patch_w
current_y
+=
patch_h
random
.
shuffle
(
patches
)
total_width
=
sum
(
patch_widths
)
+
(
w_splits
-
1
)
*
gap_size
total_height
=
sum
(
patch_heights
)
+
(
h_splits
-
1
)
*
gap_size
new_image
=
Image
.
new
(
image
.
mode
,
(
total_width
,
total_height
),
color
=
(
255
,
255
,
255
)
)
current_y
=
0
# 当前行的起始 Y 坐标
patch_idx
=
0
# 当前处理的块索引
for
i
in
range
(
h_splits
):
current_x
=
0
# 当前列的起始 X 坐标
patch_h
=
patch_heights
[
i
]
# 当前行块的高度
for
j
in
range
(
w_splits
):
# 取出打乱后的块
patch
=
patches
[
patch_idx
]
patch_w
=
patch_widths
[
j
]
# 当前列块的宽度
# 粘贴块(左上角坐标为 (current_x, current_y))
new_image
.
paste
(
patch
,
(
current_x
,
current_y
))
# 更新 X 坐标(下一个块的起始位置 = 当前块宽度 + 间隙)
current_x
+=
patch_w
+
gap_size
patch_idx
+=
1
# 更新 Y 坐标(下一行的起始位置 = 当前行高度 + 间隙)
current_y
+=
patch_h
+
gap_size
return
new_image
def
inpainting
(
image
,
num_splits
,
blank_ratio
=
0.3
,
blank_color
=
(
255
,
255
,
255
)):
"""
图像分割后随机空白部分patch,用于inpainting任务
参数:
image: PIL.Image 输入图像(RGB模式)
h_splits: int 行分割数(垂直方向分割块数)
w_splits: int 列分割数(水平方向分割块数)
blank_ratio: float 空白patch的比例(0~1)
blank_color: tuple 空白区域的颜色(RGB,如白色(255,255,255))
返回:
PIL.Image 处理后拼接的图像
"""
h_splits
,
w_splits
=
num_splits
img_w
,
img_h
=
image
.
size
base_patch_h
=
img_h
//
h_splits
patch_heights
=
[
base_patch_h
]
*
(
h_splits
-
1
)
patch_heights
.
append
(
img_h
-
sum
(
patch_heights
))
base_patch_w
=
img_w
//
w_splits
patch_widths
=
[
base_patch_w
]
*
(
w_splits
-
1
)
patch_widths
.
append
(
img_w
-
sum
(
patch_widths
))
patches
=
[]
current_y
=
0
for
i
in
range
(
h_splits
):
current_x
=
0
patch_h
=
patch_heights
[
i
]
for
j
in
range
(
w_splits
):
patch_w
=
patch_widths
[
j
]
patch
=
image
.
crop
(
(
current_x
,
current_y
,
current_x
+
patch_w
,
current_y
+
patch_h
)
)
patches
.
append
(
patch
)
current_x
+=
patch_w
current_y
+=
patch_h
total_patches
=
h_splits
*
w_splits
num_blank
=
int
(
total_patches
*
blank_ratio
)
num_blank
=
max
(
0
,
min
(
num_blank
,
total_patches
))
blank_indices
=
random
.
sample
(
range
(
total_patches
),
num_blank
)
processed_patches
=
[]
for
idx
,
patch
in
enumerate
(
patches
):
if
idx
in
blank_indices
:
blank_patch
=
Image
.
new
(
"RGB"
,
patch
.
size
,
color
=
blank_color
)
processed_patches
.
append
(
blank_patch
)
else
:
processed_patches
.
append
(
patch
)
# 创建结果图像(尺寸与原图一致)
result_image
=
Image
.
new
(
"RGB"
,
(
img_w
,
img_h
))
current_y
=
0
patch_idx
=
0
for
i
in
range
(
h_splits
):
current_x
=
0
patch_h
=
patch_heights
[
i
]
for
j
in
range
(
w_splits
):
# 取出处理后的patch
patch
=
processed_patches
[
patch_idx
]
patch_w
=
patch_widths
[
j
]
# 粘贴到原位置
result_image
.
paste
(
patch
,
(
current_x
,
current_y
))
current_x
+=
patch_w
patch_idx
+=
1
current_y
+=
patch_h
return
result_image
SenseNova-SI-main/training/bagel/data/video_utils.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2023 OpenGVLab
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under MIT, with the full license text
# available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
#
# This modified file is released under the same license.
import
io
import
os
import
random
import
re
import
decord
import
numpy
as
np
from
PIL
import
Image
def
get_frame_indices
(
num_frames
,
vlen
,
sample
=
"rand"
,
fix_start
=
None
,
input_fps
=
1
,
max_num_frames
=-
1
):
if
sample
in
[
"rand"
,
"middle"
]:
# uniform sampling
acc_samples
=
min
(
num_frames
,
vlen
)
# split the video into `acc_samples` intervals, and sample from each interval.
intervals
=
np
.
linspace
(
start
=
0
,
stop
=
vlen
,
num
=
acc_samples
+
1
).
astype
(
int
)
ranges
=
[]
for
idx
,
interv
in
enumerate
(
intervals
[:
-
1
]):
ranges
.
append
((
interv
,
intervals
[
idx
+
1
]
-
1
))
if
sample
==
"rand"
:
try
:
frame_indices
=
[
random
.
choice
(
range
(
x
[
0
],
x
[
1
]))
for
x
in
ranges
]
except
:
frame_indices
=
np
.
random
.
permutation
(
vlen
)[:
acc_samples
]
frame_indices
.
sort
()
frame_indices
=
list
(
frame_indices
)
elif
fix_start
is
not
None
:
frame_indices
=
[
x
[
0
]
+
fix_start
for
x
in
ranges
]
elif
sample
==
"middle"
:
frame_indices
=
[(
x
[
0
]
+
x
[
1
])
//
2
for
x
in
ranges
]
else
:
raise
NotImplementedError
if
len
(
frame_indices
)
<
num_frames
:
# padded with last frame
padded_frame_indices
=
[
frame_indices
[
-
1
]]
*
num_frames
padded_frame_indices
[:
len
(
frame_indices
)]
=
frame_indices
frame_indices
=
padded_frame_indices
elif
"fps"
in
sample
:
# fps0.5, sequentially sample frames at 0.5 fps
output_fps
=
float
(
sample
[
3
:])
duration
=
float
(
vlen
)
/
input_fps
delta
=
(
1
/
output_fps
)
# gap between frames, this is also the clip length each frame represents
frame_seconds
=
np
.
arange
(
0
+
delta
/
2
,
duration
+
delta
/
2
,
delta
)
frame_indices
=
np
.
around
(
frame_seconds
*
input_fps
).
astype
(
int
)
frame_indices
=
[
e
for
e
in
frame_indices
if
e
<
vlen
]
if
max_num_frames
>
0
and
len
(
frame_indices
)
>
max_num_frames
:
frame_indices
=
frame_indices
[:
max_num_frames
]
else
:
raise
ValueError
return
frame_indices
def
read_frames_decord
(
video_path
,
num_frames
,
sample
=
"rand"
,
fix_start
=
None
,
clip
=
None
,
min_num_frames
=
4
):
video_reader
=
decord
.
VideoReader
(
video_path
,
num_threads
=
1
)
vlen
=
len
(
video_reader
)
fps
=
video_reader
.
get_avg_fps
()
duration
=
vlen
/
float
(
fps
)
if
clip
:
start
,
end
=
clip
duration
=
end
-
start
vlen
=
int
(
duration
*
fps
)
start_index
=
int
(
start
*
fps
)
t_num_frames
=
np
.
random
.
randint
(
min_num_frames
,
num_frames
+
1
)
frame_indices
=
get_frame_indices
(
t_num_frames
,
vlen
,
sample
=
sample
,
fix_start
=
fix_start
,
input_fps
=
fps
)
if
clip
:
frame_indices
=
[
f
+
start_index
for
f
in
frame_indices
]
frames
=
video_reader
.
get_batch
(
frame_indices
).
asnumpy
()
# (T, H, W, C), np.uint8
frames
=
[
Image
.
fromarray
(
frames
[
i
])
for
i
in
range
(
frames
.
shape
[
0
])]
return
frames
def
extract_frame_number
(
filename
):
# Extract the numeric part from the filename using regular expressions
match
=
re
.
search
(
r
"_(\d+).jpg$"
,
filename
)
return
int
(
match
.
group
(
1
))
if
match
else
-
1
def
sort_frames
(
frame_paths
):
# Extract filenames from each path and sort by their numeric part
return
sorted
(
frame_paths
,
key
=
lambda
x
:
extract_frame_number
(
os
.
path
.
basename
(
x
)))
def
read_frames_folder
(
video_path
,
num_frames
,
sample
=
"rand"
,
fix_start
=
None
,
min_num_frames
=
4
):
image_list
=
sort_frames
(
list
(
os
.
listdir
(
video_path
)))
frames
=
[]
for
image
in
image_list
:
fp
=
os
.
path
.
join
(
video_path
,
image
)
frame
=
Image
.
open
(
fp
).
convert
(
"RGB"
)
frames
.
append
(
frame
)
vlen
=
len
(
frames
)
t_num_frames
=
np
.
random
.
randint
(
min_num_frames
,
num_frames
+
1
)
if
vlen
>
t_num_frames
:
frame_indices
=
get_frame_indices
(
t_num_frames
,
vlen
,
sample
=
sample
,
fix_start
=
fix_start
)
frames
=
[
frames
[
i
]
for
i
in
frame_indices
]
return
frames
class
FrameSampler
:
def
__init__
(
self
,
max_num_frames
=-
1
,
min_num_frames
=
8
,
sample
=
"rand"
):
self
.
max_num_frames
=
max_num_frames
self
.
min_num_frames
=
min_num_frames
self
.
sample
=
sample
def
__call__
(
self
,
file_name
):
fn
=
read_frames_folder
if
file_name
.
endswith
(
"/"
)
else
read_frames_decord
frames
=
fn
(
file_name
,
num_frames
=
self
.
max_num_frames
,
min_num_frames
=
self
.
min_num_frames
,
sample
=
self
.
sample
,
)
return
frames
def
decode_video_byte
(
video_bytes
):
video_stream
=
io
.
BytesIO
(
video_bytes
)
vr
=
decord
.
VideoReader
(
video_stream
)
return
vr
def
sample_mp4_frames
(
mp4_p
,
n_frames
=
None
,
fps
=
None
,
return_frame_indices
=
False
,
random_sample
=
False
):
if
isinstance
(
mp4_p
,
str
):
vr
=
decord
.
VideoReader
(
mp4_p
,
num_threads
=
1
)
elif
isinstance
(
mp4_p
,
decord
.
video_reader
.
VideoReader
):
vr
=
mp4_p
video_fps
=
vr
.
get_avg_fps
()
# 获取视频的帧率
video_duration
=
len
(
vr
)
/
video_fps
if
n_frames
is
not
None
:
if
random_sample
:
frame_indices
=
sorted
(
random
.
sample
(
range
(
len
(
vr
)),
n_frames
))
else
:
frame_indices
=
np
.
linspace
(
0
,
len
(
vr
)
-
1
,
n_frames
,
dtype
=
int
).
tolist
()
else
:
frame_indices
=
[
int
(
i
)
for
i
in
np
.
arange
(
0
,
len
(
vr
)
-
1
,
video_fps
/
fps
)]
frames
=
vr
.
get_batch
(
frame_indices
).
asnumpy
()
# 转换为 numpy 数组
frames
=
[
Image
.
fromarray
(
frame
).
convert
(
"RGB"
)
for
frame
in
frames
]
if
not
return_frame_indices
:
return
frames
,
video_duration
else
:
return
frames
,
video_duration
,
frame_indices
def
sample_mp4_frames_by_indices
(
mp4_p
,
frame_indices
:
list
):
if
isinstance
(
mp4_p
,
str
):
vr
=
decord
.
VideoReader
(
mp4_p
,
num_threads
=
1
)
elif
isinstance
(
mp4_p
,
decord
.
video_reader
.
VideoReader
):
vr
=
mp4_p
# sample the frames in frame_indices
frames
=
vr
.
get_batch
(
frame_indices
).
asnumpy
()
# 转换为 numpy 数组
frames
=
[
Image
.
fromarray
(
frame
).
convert
(
"RGB"
)
for
frame
in
frames
]
return
frames
SenseNova-SI-main/training/bagel/data/vlm_dataset.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
json
import
os
import
traceback
from
PIL
import
Image
,
ImageFile
,
PngImagePlugin
from
.data_utils
import
load_image
,
pil_img2rgb
from
.distributed_iterable_dataset
import
DistributedIterableDataset
Image
.
MAX_IMAGE_PIXELS
=
200000000
ImageFile
.
LOAD_TRUNCATED_IMAGES
=
True
MaximumDecompressedSize
=
1024
MegaByte
=
2
**
20
PngImagePlugin
.
MAX_TEXT_CHUNK
=
MaximumDecompressedSize
*
MegaByte
class
SftJSONLIterableDataset
(
DistributedIterableDataset
):
def
__init__
(
self
,
dataset_name
,
transform
,
tokenizer
,
frame_sampler
,
jsonl_path_list
,
data_dir_list
,
num_used_data
,
local_rank
=
0
,
world_size
=
1
,
num_workers
=
8
,
data_status
=
None
,
shuffle_lines
=
False
,
shuffle_seed
=
0
,
):
"""
jsonl_path_list: list of jsonl file paths
data_dir_list: list of image directories containing the images of each jsonl file
num_used_data: list of number of sampled data points for each jsonl
"""
super
().
__init__
(
dataset_name
,
local_rank
,
world_size
,
num_workers
)
self
.
transform
=
transform
self
.
tokenizer
=
tokenizer
self
.
frame_sampler
=
frame_sampler
self
.
data_status
=
data_status
self
.
data_paths
=
self
.
get_data_paths
(
jsonl_path_list
,
data_dir_list
,
num_used_data
,
shuffle_lines
,
shuffle_seed
,
)
self
.
set_epoch
()
def
get_data_paths
(
self
,
jsonl_path_list
,
data_dir_list
,
num_used_data
,
shuffle_lines
,
shuffle_seed
,
):
data_paths
=
[]
for
jsonl_path
,
image_dir
,
num_data_point
in
zip
(
jsonl_path_list
,
data_dir_list
,
num_used_data
):
with
open
(
jsonl_path
,
"r"
)
as
f
:
raw_data
=
f
.
readlines
()
if
shuffle_lines
:
self
.
rng
.
seed
(
shuffle_seed
)
self
.
rng
.
shuffle
(
raw_data
)
raw_data
=
raw_data
[:
num_data_point
]
data_paths
.
extend
([(
json_data
,
image_dir
)
for
json_data
in
raw_data
])
return
data_paths
def
change_format
(
self
,
data
,
num_images
):
elements
=
[]
for
conversation
in
data
[
"conversations"
]:
if
conversation
[
"from"
]
==
"human"
:
if
"<image>"
not
in
conversation
[
"value"
]:
elements
.
append
(
{
"type"
:
"text"
,
"has_loss"
:
0
,
"text"
:
conversation
[
"value"
],
}
)
else
:
text_list
=
conversation
[
"value"
].
split
(
"<image>"
)
for
idx
,
text
in
enumerate
(
text_list
):
if
text
.
strip
()
!=
""
:
elements
.
append
(
{
"type"
:
"text"
,
"has_loss"
:
0
,
"text"
:
text
.
strip
(),
}
)
if
(
idx
!=
len
(
text_list
)
-
1
)
and
(
idx
<
num_images
):
elements
.
append
(
{
"type"
:
"image"
,
}
)
elif
conversation
[
"from"
]
==
"gpt"
:
elements
.
append
(
{
"type"
:
"text"
,
"has_loss"
:
1
,
"text"
:
conversation
[
"value"
],
}
)
return
elements
def
__iter__
(
self
):
data_paths_per_worker
,
worker_id
=
self
.
get_data_paths_per_worker
()
if
self
.
data_status
is
not
None
:
row_start_id
=
self
.
data_status
[
worker_id
]
+
1
else
:
row_start_id
=
0
transform_stride
=
self
.
transform
.
stride
print
(
f
"rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
dataset-
{
self
.
dataset_name
}
: "
f
"resuming data at row#
{
row_start_id
}
"
)
while
True
:
data_paths_per_worker_
=
data_paths_per_worker
[
row_start_id
:]
for
row_idx
,
(
data
,
image_dir
)
in
enumerate
(
data_paths_per_worker_
,
start
=
row_start_id
):
num_tokens
=
0
image_tensor_list
=
[]
text_ids_list
=
[]
sequence_plan
=
[]
try
:
data_item
=
json
.
loads
(
data
)
raw_images
=
None
if
"image"
in
data_item
:
if
type
(
data_item
[
"image"
])
==
list
:
raw_images
=
[
pil_img2rgb
(
load_image
(
os
.
path
.
join
(
image_dir
,
image
)))
for
image
in
data_item
[
"image"
]
]
else
:
raw_images
=
[
pil_img2rgb
(
load_image
(
os
.
path
.
join
(
image_dir
,
data_item
[
"image"
])
)
)
]
elif
"video"
in
data_item
:
raw_images
=
self
.
frame_sampler
(
os
.
path
.
join
(
image_dir
,
data_item
[
"video"
])
)
special_tokens
=
"<image>"
*
len
(
raw_images
)
for
item
in
data_item
[
"conversations"
]:
if
"<video>"
in
item
[
"value"
]:
item
[
"value"
]
=
item
[
"value"
].
replace
(
"<video>"
,
special_tokens
)
break
else
:
raise
ValueError
(
"Cannot find <video> in the conversation!"
)
except
:
traceback
.
print_exc
()
continue
if
raw_images
:
for
raw_image
in
raw_images
:
image_tensor
=
self
.
transform
(
raw_image
,
img_num
=
len
(
raw_images
)
)
image_tensor_list
.
append
(
image_tensor
)
height
,
width
=
image_tensor
.
shape
[
1
:]
num_tokens
+=
width
*
height
//
transform_stride
**
2
elements
=
self
.
change_format
(
data_item
,
len
(
image_tensor_list
))
for
item
in
elements
:
if
item
[
"type"
]
==
"text"
:
text_data
=
item
[
"text"
]
text_ids
=
self
.
tokenizer
.
encode
(
text_data
)
if
len
(
text_ids
)
>
0
:
text_ids_list
.
append
(
text_ids
)
num_tokens
+=
len
(
text_ids
)
current_plan
=
{
"type"
:
"text"
,
"enable_cfg"
:
0
,
"loss"
:
item
[
"has_loss"
],
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
sequence_plan
.
append
(
current_plan
)
elif
item
[
"type"
]
==
"image"
:
current_plan
=
{
"type"
:
"vit_image"
,
"enable_cfg"
:
0
,
"loss"
:
0
,
"special_token_loss"
:
0
,
"special_token_label"
:
None
,
}
sequence_plan
.
append
(
current_plan
)
has_loss
=
[
item
[
"loss"
]
for
item
in
sequence_plan
]
if
sum
(
has_loss
)
==
0
:
print
(
f
"No loss defined, skipped."
)
continue
yield
dict
(
image_tensor_list
=
image_tensor_list
,
text_ids_list
=
text_ids_list
,
sequence_plan
=
sequence_plan
,
num_tokens
=
num_tokens
,
data_indexes
=
{
"data_indexes"
:
row_idx
,
"worker_id"
:
worker_id
,
"dataset_name"
:
self
.
dataset_name
,
},
)
row_start_id
=
0
print
(
f
"
{
self
.
dataset_name
}
repeat in rank-
{
self
.
local_rank
}
worker-
{
worker_id
}
"
)
SenseNova-SI-main/training/bagel/environment.yml
0 → 100644
View file @
876a36a4
name
:
bagel
channels
:
-
defaults
dependencies
:
-
_libgcc_mutex=0.1=main
-
_openmp_mutex=5.1=1_gnu
-
bzip2=1.0.8=h5eee18b_6
-
ca-certificates=2025.2.25=h06a4308_0
-
ld_impl_linux-64=2.40=h12ee557_0
-
libffi=3.4.4=h6a678d5_1
-
libgcc-ng=11.2.0=h1234567_1
-
libgomp=11.2.0=h1234567_1
-
libstdcxx-ng=11.2.0=h1234567_1
-
libuuid=1.41.5=h5eee18b_0
-
ncurses=6.4=h6a678d5_0
-
openssl=3.0.16=h5eee18b_0
-
pip=25.1=pyhc872135_2
-
python=3.10.16=he870216_1
-
readline=8.2=h5eee18b_0
-
setuptools=78.1.1=py310h06a4308_0
-
sqlite=3.45.3=h5eee18b_0
-
tk=8.6.14=h39e8969_0
-
wheel=0.45.1=py310h06a4308_0
-
xz=5.6.4=h5eee18b_1
-
zlib=1.2.13=h5eee18b_1
-
pip
:
-
accelerate==1.7.0
-
annotated-types==0.7.0
-
certifi==2025.4.26
-
charset-normalizer==3.4.2
-
click==8.2.1
-
contourpy==1.3.2
-
cycler==0.12.1
-
decord==0.6.0
-
docker-pycreds==0.4.0
-
einops==0.8.1
-
filelock==3.18.0
-
fonttools==4.58.0
-
fsspec==2025.5.1
-
gitdb==4.0.12
-
gitpython==3.1.44
-
huggingface-hub==0.29.1
-
idna==3.10
-
jinja2==3.1.6
-
kiwisolver==1.4.8
-
markupsafe==3.0.2
-
matplotlib==3.7.0
-
mpmath==1.3.0
-
networkx==3.4.2
-
ninja==1.11.1.4
-
numpy==1.24.4
-
nvidia-cublas-cu12==12.4.5.8
-
nvidia-cuda-cupti-cu12==12.4.127
-
nvidia-cuda-nvrtc-cu12==12.4.127
-
nvidia-cuda-runtime-cu12==12.4.127
-
nvidia-cudnn-cu12==9.1.0.70
-
nvidia-cufft-cu12==11.2.1.3
-
nvidia-curand-cu12==10.3.5.147
-
nvidia-cusolver-cu12==11.6.1.9
-
nvidia-cusparse-cu12==12.3.1.170
-
nvidia-nccl-cu12==2.21.5
-
nvidia-nvjitlink-cu12==12.4.127
-
nvidia-nvtx-cu12==12.4.127
-
opencv-python==4.7.0.72
-
packaging==25.0
-
pandas==2.3.0
-
pillow==11.2.1
-
platformdirs==4.3.8
-
protobuf==6.31.0
-
psutil==7.0.0
-
pyarrow==11.0.0
-
pydantic==2.11.5
-
pydantic-core==2.33.2
-
pyparsing==3.2.3
-
python-dateutil==2.9.0.post0
-
pytz==2025.2
-
pyyaml==6.0.2
-
regex==2024.11.6
-
requests==2.32.3
-
safetensors==0.4.5
-
scipy==1.10.1
-
sentencepiece==0.1.99
-
sentry-sdk==2.29.1
-
setproctitle==1.3.6
-
six==1.17.0
-
smmap==5.0.2
-
sympy==1.13.1
-
tokenizers==0.21.1
-
torch==2.5.1
-
torchvision==0.20.1
-
tqdm==4.67.1
-
transformers==4.49.0
-
triton==3.1.0
-
typing-extensions==4.13.2
-
typing-inspection==0.4.1
-
tzdata==2025.2
-
urllib3==2.4.0
-
wandb==0.19.11
\ No newline at end of file
SenseNova-SI-main/training/bagel/modeling/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
.
import
autoencoder
,
bagel
,
qwen2
,
siglip
SenseNova-SI-main/training/bagel/modeling/autoencoder.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2024 Black Forest Labs.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
#
# This modified file is released under the same license.
from
dataclasses
import
dataclass
import
torch
from
einops
import
rearrange
from
safetensors.torch
import
load_file
as
load_sft
from
torch
import
Tensor
,
nn
@
dataclass
class
AutoEncoderParams
:
resolution
:
int
in_channels
:
int
downsample
:
int
ch
:
int
out_ch
:
int
ch_mult
:
list
[
int
]
num_res_blocks
:
int
z_channels
:
int
scale_factor
:
float
shift_factor
:
float
def
swish
(
x
:
Tensor
)
->
Tensor
:
return
x
*
torch
.
sigmoid
(
x
)
class
AttnBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
:
int
):
super
().
__init__
()
self
.
in_channels
=
in_channels
self
.
norm
=
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
self
.
q
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
)
self
.
k
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
)
self
.
v
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
)
self
.
proj_out
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
1
)
def
attention
(
self
,
h_
:
Tensor
)
->
Tensor
:
h_
=
self
.
norm
(
h_
)
q
=
self
.
q
(
h_
)
k
=
self
.
k
(
h_
)
v
=
self
.
v
(
h_
)
b
,
c
,
h
,
w
=
q
.
shape
q
=
rearrange
(
q
,
"b c h w -> b 1 (h w) c"
).
contiguous
()
k
=
rearrange
(
k
,
"b c h w -> b 1 (h w) c"
).
contiguous
()
v
=
rearrange
(
v
,
"b c h w -> b 1 (h w) c"
).
contiguous
()
h_
=
nn
.
functional
.
scaled_dot_product_attention
(
q
,
k
,
v
)
return
rearrange
(
h_
,
"b 1 (h w) c -> b c h w"
,
h
=
h
,
w
=
w
,
c
=
c
,
b
=
b
)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
return
x
+
self
.
proj_out
(
self
.
attention
(
x
))
class
ResnetBlock
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
:
int
,
out_channels
:
int
):
super
().
__init__
()
self
.
in_channels
=
in_channels
out_channels
=
in_channels
if
out_channels
is
None
else
out_channels
self
.
out_channels
=
out_channels
self
.
norm1
=
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
in_channels
,
eps
=
1e-6
,
affine
=
True
)
self
.
conv1
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
self
.
norm2
=
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
out_channels
,
eps
=
1e-6
,
affine
=
True
)
self
.
conv2
=
nn
.
Conv2d
(
out_channels
,
out_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
if
self
.
in_channels
!=
self
.
out_channels
:
self
.
nin_shortcut
=
nn
.
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
,
stride
=
1
,
padding
=
0
)
def
forward
(
self
,
x
):
h
=
x
h
=
self
.
norm1
(
h
)
h
=
swish
(
h
)
h
=
self
.
conv1
(
h
)
h
=
self
.
norm2
(
h
)
h
=
swish
(
h
)
h
=
self
.
conv2
(
h
)
if
self
.
in_channels
!=
self
.
out_channels
:
x
=
self
.
nin_shortcut
(
x
)
return
x
+
h
class
Downsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
:
int
):
super
().
__init__
()
# no asymmetric padding in torch conv, must do it ourselves
self
.
conv
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
2
,
padding
=
0
)
def
forward
(
self
,
x
:
Tensor
):
pad
=
(
0
,
1
,
0
,
1
)
x
=
nn
.
functional
.
pad
(
x
,
pad
,
mode
=
"constant"
,
value
=
0
)
x
=
self
.
conv
(
x
)
return
x
class
Upsample
(
nn
.
Module
):
def
__init__
(
self
,
in_channels
:
int
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
in_channels
,
in_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
:
Tensor
):
x
=
nn
.
functional
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
x
=
self
.
conv
(
x
)
return
x
class
Encoder
(
nn
.
Module
):
def
__init__
(
self
,
resolution
:
int
,
in_channels
:
int
,
ch
:
int
,
ch_mult
:
list
[
int
],
num_res_blocks
:
int
,
z_channels
:
int
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
# downsampling
self
.
conv_in
=
nn
.
Conv2d
(
in_channels
,
self
.
ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
curr_res
=
resolution
in_ch_mult
=
(
1
,)
+
tuple
(
ch_mult
)
self
.
in_ch_mult
=
in_ch_mult
self
.
down
=
nn
.
ModuleList
()
block_in
=
self
.
ch
for
i_level
in
range
(
self
.
num_resolutions
):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_in
=
ch
*
in_ch_mult
[
i_level
]
block_out
=
ch
*
ch_mult
[
i_level
]
for
_
in
range
(
self
.
num_res_blocks
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
))
block_in
=
block_out
down
=
nn
.
Module
()
down
.
block
=
block
down
.
attn
=
attn
if
i_level
!=
self
.
num_resolutions
-
1
:
down
.
downsample
=
Downsample
(
block_in
)
curr_res
=
curr_res
//
2
self
.
down
.
append
(
down
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
)
# end
self
.
norm_out
=
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
block_in
,
eps
=
1e-6
,
affine
=
True
)
self
.
conv_out
=
nn
.
Conv2d
(
block_in
,
2
*
z_channels
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
# downsampling
hs
=
[
self
.
conv_in
(
x
)]
for
i_level
in
range
(
self
.
num_resolutions
):
for
i_block
in
range
(
self
.
num_res_blocks
):
h
=
self
.
down
[
i_level
].
block
[
i_block
](
hs
[
-
1
])
if
len
(
self
.
down
[
i_level
].
attn
)
>
0
:
h
=
self
.
down
[
i_level
].
attn
[
i_block
](
h
)
hs
.
append
(
h
)
if
i_level
!=
self
.
num_resolutions
-
1
:
hs
.
append
(
self
.
down
[
i_level
].
downsample
(
hs
[
-
1
]))
# middle
h
=
hs
[
-
1
]
h
=
self
.
mid
.
block_1
(
h
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
)
# end
h
=
self
.
norm_out
(
h
)
h
=
swish
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
Decoder
(
nn
.
Module
):
def
__init__
(
self
,
ch
:
int
,
out_ch
:
int
,
ch_mult
:
list
[
int
],
num_res_blocks
:
int
,
in_channels
:
int
,
resolution
:
int
,
z_channels
:
int
,
):
super
().
__init__
()
self
.
ch
=
ch
self
.
num_resolutions
=
len
(
ch_mult
)
self
.
num_res_blocks
=
num_res_blocks
self
.
resolution
=
resolution
self
.
in_channels
=
in_channels
self
.
ffactor
=
2
**
(
self
.
num_resolutions
-
1
)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in
=
ch
*
ch_mult
[
self
.
num_resolutions
-
1
]
curr_res
=
resolution
//
2
**
(
self
.
num_resolutions
-
1
)
self
.
z_shape
=
(
1
,
z_channels
,
curr_res
,
curr_res
)
# z to block_in
self
.
conv_in
=
nn
.
Conv2d
(
z_channels
,
block_in
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
# middle
self
.
mid
=
nn
.
Module
()
self
.
mid
.
block_1
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
)
self
.
mid
.
attn_1
=
AttnBlock
(
block_in
)
self
.
mid
.
block_2
=
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_in
)
# upsampling
self
.
up
=
nn
.
ModuleList
()
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
block
=
nn
.
ModuleList
()
attn
=
nn
.
ModuleList
()
block_out
=
ch
*
ch_mult
[
i_level
]
for
_
in
range
(
self
.
num_res_blocks
+
1
):
block
.
append
(
ResnetBlock
(
in_channels
=
block_in
,
out_channels
=
block_out
))
block_in
=
block_out
up
=
nn
.
Module
()
up
.
block
=
block
up
.
attn
=
attn
if
i_level
!=
0
:
up
.
upsample
=
Upsample
(
block_in
)
curr_res
=
curr_res
*
2
self
.
up
.
insert
(
0
,
up
)
# prepend to get consistent order
# end
self
.
norm_out
=
nn
.
GroupNorm
(
num_groups
=
32
,
num_channels
=
block_in
,
eps
=
1e-6
,
affine
=
True
)
self
.
conv_out
=
nn
.
Conv2d
(
block_in
,
out_ch
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
)
def
forward
(
self
,
z
:
Tensor
)
->
Tensor
:
# z to block_in
h
=
self
.
conv_in
(
z
)
# middle
h
=
self
.
mid
.
block_1
(
h
)
h
=
self
.
mid
.
attn_1
(
h
)
h
=
self
.
mid
.
block_2
(
h
)
# upsampling
for
i_level
in
reversed
(
range
(
self
.
num_resolutions
)):
for
i_block
in
range
(
self
.
num_res_blocks
+
1
):
h
=
self
.
up
[
i_level
].
block
[
i_block
](
h
)
if
len
(
self
.
up
[
i_level
].
attn
)
>
0
:
h
=
self
.
up
[
i_level
].
attn
[
i_block
](
h
)
if
i_level
!=
0
:
h
=
self
.
up
[
i_level
].
upsample
(
h
)
# end
h
=
self
.
norm_out
(
h
)
h
=
swish
(
h
)
h
=
self
.
conv_out
(
h
)
return
h
class
DiagonalGaussian
(
nn
.
Module
):
def
__init__
(
self
,
sample
:
bool
=
True
,
chunk_dim
:
int
=
1
):
super
().
__init__
()
self
.
sample
=
sample
self
.
chunk_dim
=
chunk_dim
def
forward
(
self
,
z
:
Tensor
)
->
Tensor
:
mean
,
logvar
=
torch
.
chunk
(
z
,
2
,
dim
=
self
.
chunk_dim
)
if
self
.
sample
:
std
=
torch
.
exp
(
0.5
*
logvar
)
return
mean
+
std
*
torch
.
randn_like
(
mean
)
else
:
return
mean
class
AutoEncoder
(
nn
.
Module
):
def
__init__
(
self
,
params
:
AutoEncoderParams
):
super
().
__init__
()
self
.
encoder
=
Encoder
(
resolution
=
params
.
resolution
,
in_channels
=
params
.
in_channels
,
ch
=
params
.
ch
,
ch_mult
=
params
.
ch_mult
,
num_res_blocks
=
params
.
num_res_blocks
,
z_channels
=
params
.
z_channels
,
)
self
.
decoder
=
Decoder
(
resolution
=
params
.
resolution
,
in_channels
=
params
.
in_channels
,
ch
=
params
.
ch
,
out_ch
=
params
.
out_ch
,
ch_mult
=
params
.
ch_mult
,
num_res_blocks
=
params
.
num_res_blocks
,
z_channels
=
params
.
z_channels
,
)
self
.
reg
=
DiagonalGaussian
()
self
.
scale_factor
=
params
.
scale_factor
self
.
shift_factor
=
params
.
shift_factor
def
encode
(
self
,
x
:
Tensor
)
->
Tensor
:
z
=
self
.
reg
(
self
.
encoder
(
x
))
z
=
self
.
scale_factor
*
(
z
-
self
.
shift_factor
)
return
z
def
decode
(
self
,
z
:
Tensor
)
->
Tensor
:
z
=
z
/
self
.
scale_factor
+
self
.
shift_factor
return
self
.
decoder
(
z
)
def
forward
(
self
,
x
:
Tensor
)
->
Tensor
:
return
self
.
decode
(
self
.
encode
(
x
))
def
print_load_warning
(
missing
:
list
[
str
],
unexpected
:
list
[
str
])
->
None
:
if
len
(
missing
)
>
0
and
len
(
unexpected
)
>
0
:
print
(
f
"Got
{
len
(
missing
)
}
missing keys:
\n\t
"
+
"
\n\t
"
.
join
(
missing
))
print
(
"
\n
"
+
"-"
*
79
+
"
\n
"
)
print
(
f
"Got
{
len
(
unexpected
)
}
unexpected keys:
\n\t
"
+
"
\n\t
"
.
join
(
unexpected
))
elif
len
(
missing
)
>
0
:
print
(
f
"Got
{
len
(
missing
)
}
missing keys:
\n\t
"
+
"
\n\t
"
.
join
(
missing
))
elif
len
(
unexpected
)
>
0
:
print
(
f
"Got
{
len
(
unexpected
)
}
unexpected keys:
\n\t
"
+
"
\n\t
"
.
join
(
unexpected
))
def
load_ae
(
local_path
:
str
|
None
)
->
tuple
[
AutoEncoder
,
AutoEncoderParams
]:
ae_params
=
AutoEncoderParams
(
resolution
=
256
,
in_channels
=
3
,
downsample
=
8
,
ch
=
128
,
out_ch
=
3
,
ch_mult
=
[
1
,
2
,
4
,
4
],
num_res_blocks
=
2
,
z_channels
=
16
,
scale_factor
=
0.3611
,
shift_factor
=
0.1159
,
)
# Loading the autoencoder
ae
=
AutoEncoder
(
ae_params
)
if
local_path
is
not
None
:
sd
=
load_sft
(
local_path
)
missing
,
unexpected
=
ae
.
load_state_dict
(
sd
,
strict
=
False
,
assign
=
True
)
print_load_warning
(
missing
,
unexpected
)
return
ae
,
ae_params
SenseNova-SI-main/training/bagel/modeling/bagel/__init__.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from
.bagel
import
Bagel
,
BagelConfig
from
.qwen2_navit
import
Qwen2Config
,
Qwen2ForCausalLM
,
Qwen2Model
from
.siglip_navit
import
SiglipVisionConfig
,
SiglipVisionModel
__all__
=
[
"BagelConfig"
,
"Bagel"
,
"Qwen2Config"
,
"Qwen2Model"
,
"Qwen2ForCausalLM"
,
"SiglipVisionConfig"
,
"SiglipVisionModel"
,
]
SenseNova-SI-main/training/bagel/modeling/bagel/bagel.py
0 → 100644
View file @
876a36a4
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import
copy
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch.nn.functional
as
F
from
data.data_utils
import
(
create_sparse_mask
,
get_flattened_position_ids_extrapolate
,
get_flattened_position_ids_interpolate
,
patchify
,
)
from
torch
import
nn
from
torch.nn.attention.flex_attention
import
create_block_mask
from
tqdm
import
tqdm
from
transformers.configuration_utils
import
PretrainedConfig
from
transformers.modeling_utils
import
PreTrainedModel
from
.modeling_utils
import
MLPconnector
,
PositionEmbedding
,
TimestepEmbedder
from
.qwen2_navit
import
NaiveCache
class
BagelConfig
(
PretrainedConfig
):
def
__init__
(
self
,
visual_gen
=
True
,
visual_und
=
True
,
llm_config
=
None
,
vit_config
=
None
,
vae_config
=
None
,
latent_patch_size
=
2
,
max_latent_size
=
32
,
vit_max_num_patch_per_side
=
70
,
connector_act
=
"gelu_pytorch_tanh"
,
interpolate_pos
=
False
,
timestep_shift
=
1.0
,
**
kwargs
,
):
super
().
__init__
(
**
kwargs
)
self
.
visual_gen
=
visual_gen
self
.
visual_und
=
visual_und
self
.
llm_config
=
llm_config
self
.
vit_config
=
vit_config
self
.
vae_config
=
vae_config
self
.
latent_patch_size
=
latent_patch_size
self
.
max_latent_size
=
max_latent_size
self
.
vit_max_num_patch_per_side
=
vit_max_num_patch_per_side
self
.
connector_act
=
connector_act
self
.
interpolate_pos
=
interpolate_pos
self
.
timestep_shift
=
timestep_shift
class
Bagel
(
PreTrainedModel
):
config_class
=
BagelConfig
base_model_prefix
=
"bagel"
def
__init__
(
self
,
language_model
,
vit_model
,
config
:
BagelConfig
):
super
().
__init__
(
config
)
self
.
language_model
=
language_model
if
config
.
llm_config
is
None
:
raise
ValueError
(
"llm_config cannot be None"
)
self
.
hidden_size
=
config
.
llm_config
.
hidden_size
self
.
use_moe
=
"Mo"
in
config
.
llm_config
.
layer_module
self
.
num_heads
=
config
.
llm_config
.
num_attention_heads
if
config
.
visual_gen
:
if
config
.
vae_config
is
None
:
raise
ValueError
(
"vae_config cannot be None when visual_gen is True"
)
self
.
latent_patch_size
=
config
.
latent_patch_size
self
.
timestep_shift
=
config
.
timestep_shift
self
.
latent_downsample
=
(
config
.
vae_config
.
downsample
*
config
.
latent_patch_size
)
self
.
max_latent_size
=
config
.
max_latent_size
self
.
latent_channel
=
config
.
vae_config
.
z_channels
self
.
patch_latent_dim
=
self
.
latent_patch_size
**
2
*
self
.
latent_channel
self
.
time_embedder
=
TimestepEmbedder
(
self
.
hidden_size
)
self
.
vae2llm
=
nn
.
Linear
(
self
.
patch_latent_dim
,
self
.
hidden_size
)
self
.
llm2vae
=
nn
.
Linear
(
self
.
hidden_size
,
self
.
patch_latent_dim
)
self
.
latent_pos_embed
=
PositionEmbedding
(
self
.
max_latent_size
,
self
.
hidden_size
)
if
config
.
visual_und
:
if
config
.
vit_config
is
None
:
raise
ValueError
(
"vit_config cannot be None when visual_und is True"
)
self
.
vit_model
=
vit_model
self
.
vit_patch_size
=
config
.
vit_config
.
patch_size
self
.
vit_max_num_patch_per_side
=
config
.
vit_max_num_patch_per_side
self
.
vit_hidden_size
=
config
.
vit_config
.
hidden_size
self
.
connector
=
MLPconnector
(
self
.
vit_hidden_size
,
self
.
hidden_size
,
config
.
connector_act
)
self
.
vit_pos_embed
=
PositionEmbedding
(
self
.
vit_max_num_patch_per_side
,
self
.
hidden_size
)
if
config
.
interpolate_pos
:
self
.
get_flattened_position_ids
=
get_flattened_position_ids_interpolate
else
:
self
.
get_flattened_position_ids
=
get_flattened_position_ids_extrapolate
self
.
config
=
config
self
.
_init_weights
()
def
_init_weights
(
self
):
if
self
.
config
.
visual_gen
:
nn
.
init
.
constant_
(
self
.
llm2vae
.
weight
,
0
)
nn
.
init
.
constant_
(
self
.
llm2vae
.
bias
,
0
)
def
forward
(
self
,
sequence_length
:
int
,
packed_text_ids
:
torch
.
LongTensor
,
packed_text_indexes
:
torch
.
LongTensor
,
sample_lens
:
List
[
int
],
packed_position_ids
:
torch
.
LongTensor
,
nested_attention_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
split_lens
:
Optional
[
List
[
int
]]
=
None
,
attn_modes
:
Optional
[
List
[
str
]]
=
None
,
# for visual understanding
ce_loss_indexes
:
Optional
[
torch
.
BoolTensor
]
=
None
,
packed_label_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
packed_vit_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_vit_token_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
packed_vit_position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
vit_token_seqlens
:
Optional
[
torch
.
IntTensor
]
=
None
,
# for visual generation
padded_latent
:
Optional
[
torch
.
Tensor
]
=
None
,
patchified_vae_latent_shapes
:
Optional
[
List
[
Tuple
[
int
,
int
]]]
=
None
,
packed_latent_position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
packed_vae_token_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
packed_timesteps
:
Optional
[
torch
.
LongTensor
]
=
None
,
mse_loss_indexes
:
Optional
[
torch
.
BoolTensor
]
=
None
,
)
->
torch
.
Tensor
:
"""
Args:
sequence_length: length of sequence.
packed_text_ids: 1-D int tensor, packed text token ids.
packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
sample_lens: A list of N ints, length of each sample in packed_sequence.
nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and
-inf means ignore.
packed_position_ids: packed 1-D positions, an image has only one global position shared
by all latent tokens.
packed_vit_tokens: packed patchified image tokens for vit model.
packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
packed_label_ids: 1-D int tensor, packed label token ids.
ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
padded_latent: padded latent from VAE encoder.
patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image.
packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
"""
packed_text_embedding
=
self
.
language_model
.
model
.
embed_tokens
(
packed_text_ids
)
packed_sequence
=
packed_text_embedding
.
new_zeros
(
size
=
(
sequence_length
,
self
.
hidden_size
)
)
packed_sequence
[
packed_text_indexes
]
=
packed_text_embedding
if
nested_attention_masks
is
None
:
sparse_mask
=
create_sparse_mask
(
sample_lens
,
split_lens
,
attn_modes
,
packed_text_embedding
.
device
)
seqlen
=
sum
(
sample_lens
)
block_mask
=
create_block_mask
(
sparse_mask
,
B
=
1
,
H
=
self
.
num_heads
,
Q_LEN
=
seqlen
,
KV_LEN
=
seqlen
,
device
=
packed_text_embedding
.
device
,
BLOCK_SIZE
=
128
,
_compile
=
True
,
)
attention_mask
=
block_mask
else
:
attention_mask
=
nested_attention_masks
if
self
.
config
.
visual_und
:
cu_seqlens
=
torch
.
nn
.
functional
.
pad
(
torch
.
cumsum
(
vit_token_seqlens
,
dim
=
0
),
(
1
,
0
)
)
cu_seqlens
=
cu_seqlens
.
to
(
torch
.
int32
)
max_seqlen
=
torch
.
max
(
vit_token_seqlens
).
item
()
packed_vit_token_embed
=
self
.
vit_model
(
packed_pixel_values
=
packed_vit_tokens
,
packed_flattened_position_ids
=
packed_vit_position_ids
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
)
packed_vit_token_embed
=
self
.
connector
(
packed_vit_token_embed
)
vit_token_pos_emb
=
self
.
vit_pos_embed
(
packed_vit_position_ids
)
packed_vit_token_embed
=
packed_vit_token_embed
+
vit_token_pos_emb
packed_sequence
[
packed_vit_token_indexes
]
=
packed_vit_token_embed
if
self
.
config
.
visual_gen
:
p
=
self
.
latent_patch_size
packed_latent
=
[]
for
latent
,
(
h
,
w
)
in
zip
(
padded_latent
,
patchified_vae_latent_shapes
):
latent
=
latent
[:,
:
h
*
p
,
:
w
*
p
].
reshape
(
self
.
latent_channel
,
h
,
p
,
w
,
p
)
latent
=
torch
.
einsum
(
"chpwq->hwpqc"
,
latent
).
reshape
(
-
1
,
p
*
p
*
self
.
latent_channel
)
packed_latent
.
append
(
latent
)
packed_latent_clean
=
torch
.
cat
(
packed_latent
,
dim
=
0
)
noise
=
torch
.
randn_like
(
packed_latent_clean
)
packed_timesteps
=
torch
.
sigmoid
(
packed_timesteps
)
packed_timesteps
=
(
self
.
timestep_shift
*
packed_timesteps
/
(
1
+
(
self
.
timestep_shift
-
1
)
*
packed_timesteps
)
)
packed_latent
=
(
1
-
packed_timesteps
[:,
None
]
)
*
packed_latent_clean
+
packed_timesteps
[:,
None
]
*
noise
packed_timestep_embeds
=
self
.
time_embedder
(
packed_timesteps
)
latent_token_pos_emb
=
self
.
latent_pos_embed
(
packed_latent_position_ids
)
packed_latent
=
(
self
.
vae2llm
(
packed_latent
)
+
packed_timestep_embeds
+
latent_token_pos_emb
)
packed_sequence
[
packed_vae_token_indexes
]
=
packed_latent
extra_inputs
=
{}
if
self
.
use_moe
:
packed_und_token_indexes
=
packed_text_indexes
if
packed_vit_token_indexes
is
not
None
:
packed_und_token_indexes
=
torch
.
cat
(
[
packed_text_indexes
,
packed_vit_token_indexes
],
dim
=
0
)
extra_inputs
.
update
(
packed_und_token_indexes
=
packed_und_token_indexes
,
packed_gen_token_indexes
=
packed_vae_token_indexes
,
)
last_hidden_state
=
self
.
language_model
(
packed_sequence
=
packed_sequence
,
sample_lens
=
sample_lens
,
attention_mask
=
attention_mask
,
packed_position_ids
=
packed_position_ids
,
**
extra_inputs
,
)
mse
=
None
if
self
.
config
.
visual_gen
:
packed_mse_preds
=
self
.
llm2vae
(
last_hidden_state
[
mse_loss_indexes
])
target
=
(
noise
-
packed_latent_clean
)
# NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
has_mse
=
packed_timesteps
>
0
mse
=
(
packed_mse_preds
-
target
[
has_mse
])
**
2
ce
=
None
if
ce_loss_indexes
is
not
None
:
packed_ce_preds
=
self
.
language_model
.
lm_head
(
last_hidden_state
[
ce_loss_indexes
]
)
ce
=
F
.
cross_entropy
(
packed_ce_preds
,
packed_label_ids
,
reduction
=
"none"
)
return
dict
(
mse
=
mse
,
ce
=
ce
)
def
prepare_prompts
(
self
,
curr_kvlens
,
curr_rope
,
prompts
,
tokenizer
,
new_token_ids
):
packed_text_ids
=
list
()
packed_text_position_ids
=
list
()
text_token_lens
=
list
()
packed_text_indexes
=
list
()
packed_key_value_indexes
=
list
()
curr
=
0
newlens
,
new_rope
=
list
(),
list
()
for
prompt
,
curr_kvlen
,
curr_position_id
in
zip
(
prompts
,
curr_kvlens
,
curr_rope
):
packed_key_value_indexes
.
extend
(
range
(
curr
,
curr
+
curr_kvlen
))
curr
+=
curr_kvlen
text_ids
=
tokenizer
.
encode
(
prompt
)
text_ids
=
(
[
new_token_ids
[
"bos_token_id"
]]
+
text_ids
+
[
new_token_ids
[
"eos_token_id"
]]
)
text_token_lens
.
append
(
len
(
text_ids
))
packed_text_ids
.
extend
(
text_ids
)
packed_text_position_ids
.
extend
(
range
(
curr_position_id
,
curr_position_id
+
len
(
text_ids
))
)
packed_text_indexes
.
extend
(
range
(
curr
,
curr
+
len
(
text_ids
)))
newlens
.
append
(
curr_kvlen
+
len
(
text_ids
))
new_rope
.
append
(
curr_position_id
+
len
(
text_ids
))
curr
+=
len
(
text_ids
)
generation_input
=
{
"text_token_lens"
:
torch
.
tensor
(
text_token_lens
,
dtype
=
torch
.
int
),
"packed_text_ids"
:
torch
.
tensor
(
packed_text_ids
,
dtype
=
torch
.
long
),
"packed_text_position_ids"
:
torch
.
tensor
(
packed_text_position_ids
,
dtype
=
torch
.
long
),
"packed_text_indexes"
:
torch
.
tensor
(
packed_text_indexes
,
dtype
=
torch
.
long
),
"packed_key_value_indexes"
:
torch
.
tensor
(
packed_key_value_indexes
,
dtype
=
torch
.
long
),
"key_values_lens"
:
torch
.
tensor
(
curr_kvlens
,
dtype
=
torch
.
int
),
}
return
generation_input
,
newlens
,
new_rope
@
torch
.
no_grad
def
forward_cache_update_text
(
self
,
past_key_values
:
NaiveCache
,
packed_text_ids
:
torch
.
IntTensor
,
packed_text_position_ids
:
torch
.
LongTensor
,
text_token_lens
:
torch
.
LongTensor
,
packed_text_indexes
:
torch
.
LongTensor
,
packed_key_value_indexes
:
torch
.
LongTensor
,
key_values_lens
:
torch
.
IntTensor
,
):
packed_text_embedding
=
self
.
language_model
.
model
.
embed_tokens
(
packed_text_ids
)
extra_inputs
=
{}
if
self
.
use_moe
:
extra_inputs
=
{
"mode"
:
"und"
}
output
=
self
.
language_model
.
forward_inference
(
packed_query_sequence
=
packed_text_embedding
,
query_lens
=
text_token_lens
,
packed_query_position_ids
=
packed_text_position_ids
,
packed_query_indexes
=
packed_text_indexes
,
past_key_values
=
past_key_values
,
packed_key_value_indexes
=
packed_key_value_indexes
,
key_values_lens
=
key_values_lens
,
update_past_key_values
=
True
,
is_causal
=
True
,
**
extra_inputs
,
)
past_key_values
=
output
.
past_key_values
return
past_key_values
def
prepare_vit_images
(
self
,
curr_kvlens
,
curr_rope
,
images
,
transforms
,
new_token_ids
):
packed_vit_token_indexes
=
list
()
vit_token_seqlens
,
packed_vit_tokens
,
packed_vit_position_ids
=
(
list
(),
list
(),
list
(),
)
packed_text_ids
,
packed_text_indexes
=
list
(),
list
()
packed_seqlens
,
packed_position_ids
,
packed_indexes
=
list
(),
list
(),
list
()
packed_key_value_indexes
=
list
()
_curr
=
curr
=
0
newlens
,
new_rope
=
list
(),
list
()
for
image
,
curr_kvlen
,
curr_position_id
in
zip
(
images
,
curr_kvlens
,
curr_rope
):
packed_key_value_indexes
.
extend
(
range
(
curr
,
curr
+
curr_kvlen
))
curr
+=
curr_kvlen
packed_text_ids
.
append
(
new_token_ids
[
"start_of_image"
])
packed_text_indexes
.
append
(
_curr
)
packed_indexes
.
append
(
curr
)
curr
+=
1
_curr
+=
1
image_tensor
=
transforms
(
image
)
vit_position_ids
=
self
.
get_flattened_position_ids
(
image_tensor
.
size
(
1
),
image_tensor
.
size
(
2
),
self
.
vit_patch_size
,
max_num_patches_per_side
=
self
.
vit_max_num_patch_per_side
,
)
vit_tokens
=
patchify
(
image_tensor
,
self
.
vit_patch_size
)
packed_vit_tokens
.
append
(
vit_tokens
)
num_img_tokens
=
vit_tokens
.
shape
[
0
]
packed_vit_position_ids
.
append
(
vit_position_ids
)
vit_token_seqlens
.
append
(
num_img_tokens
)
packed_vit_token_indexes
.
extend
(
range
(
_curr
,
_curr
+
num_img_tokens
))
packed_indexes
.
extend
(
range
(
curr
,
curr
+
num_img_tokens
))
curr
+=
num_img_tokens
_curr
+=
num_img_tokens
packed_text_ids
.
append
(
new_token_ids
[
"end_of_image"
])
packed_text_indexes
.
append
(
_curr
)
packed_indexes
.
append
(
curr
)
curr
+=
1
_curr
+=
1
packed_position_ids
.
extend
([
curr_position_id
]
*
(
num_img_tokens
+
2
))
packed_seqlens
.
append
(
num_img_tokens
+
2
)
newlens
.
append
(
curr_kvlen
+
num_img_tokens
+
2
)
new_rope
.
append
(
curr_position_id
+
1
)
generation_input
=
{
"packed_text_ids"
:
torch
.
tensor
(
packed_text_ids
,
dtype
=
torch
.
long
),
"packed_text_indexes"
:
torch
.
tensor
(
packed_text_indexes
,
dtype
=
torch
.
long
),
"vit_token_seqlens"
:
torch
.
tensor
(
vit_token_seqlens
,
dtype
=
torch
.
int
),
"packed_vit_tokens"
:
torch
.
cat
(
packed_vit_tokens
,
dim
=
0
),
"packed_vit_position_ids"
:
torch
.
cat
(
packed_vit_position_ids
,
dim
=
0
),
"packed_vit_token_indexes"
:
torch
.
tensor
(
packed_vit_token_indexes
,
dtype
=
torch
.
long
),
"packed_position_ids"
:
torch
.
tensor
(
packed_position_ids
,
dtype
=
torch
.
long
),
"packed_seqlens"
:
torch
.
tensor
(
packed_seqlens
,
dtype
=
torch
.
int
),
"packed_indexes"
:
torch
.
tensor
(
packed_indexes
,
dtype
=
torch
.
long
),
"packed_key_value_indexes"
:
torch
.
tensor
(
packed_key_value_indexes
,
dtype
=
torch
.
long
),
"key_values_lens"
:
torch
.
tensor
(
curr_kvlens
,
dtype
=
torch
.
int
),
}
return
generation_input
,
newlens
,
new_rope
@
torch
.
no_grad
def
forward_cache_update_vit
(
self
,
past_key_values
:
NaiveCache
,
packed_text_ids
:
torch
.
LongTensor
,
packed_text_indexes
:
torch
.
LongTensor
,
packed_vit_tokens
:
torch
.
Tensor
,
packed_vit_token_indexes
:
torch
.
LongTensor
,
packed_vit_position_ids
:
torch
.
LongTensor
,
vit_token_seqlens
:
torch
.
IntTensor
,
packed_position_ids
:
torch
.
LongTensor
,
packed_seqlens
:
torch
.
IntTensor
,
packed_indexes
:
torch
.
LongTensor
,
packed_key_value_indexes
:
torch
.
LongTensor
,
key_values_lens
:
torch
.
IntTensor
,
):
packed_text_embedding
=
self
.
language_model
.
model
.
embed_tokens
(
packed_text_ids
)
packed_sequence
=
packed_text_embedding
.
new_zeros
(
(
sum
(
packed_seqlens
),
self
.
hidden_size
)
)
packed_sequence
[
packed_text_indexes
]
=
packed_text_embedding
cu_seqlens
=
torch
.
nn
.
functional
.
pad
(
torch
.
cumsum
(
vit_token_seqlens
,
dim
=
0
),
(
1
,
0
)
)
cu_seqlens
=
cu_seqlens
.
to
(
torch
.
int32
)
max_seqlen
=
torch
.
max
(
vit_token_seqlens
).
item
()
packed_vit_token_embed
=
self
.
vit_model
(
packed_pixel_values
=
packed_vit_tokens
,
packed_flattened_position_ids
=
packed_vit_position_ids
,
cu_seqlens
=
cu_seqlens
,
max_seqlen
=
max_seqlen
,
)
packed_vit_token_embed
=
self
.
connector
(
packed_vit_token_embed
)
pos_emb
=
self
.
vit_pos_embed
(
packed_vit_position_ids
)
packed_vit_token_embed
=
packed_vit_token_embed
+
pos_emb
if
packed_vit_token_embed
.
dtype
!=
packed_sequence
.
dtype
:
packed_vit_token_embed
=
packed_vit_token_embed
.
to
(
packed_sequence
.
dtype
)
packed_sequence
[
packed_vit_token_indexes
]
=
packed_vit_token_embed
extra_inputs
=
{}
if
self
.
use_moe
:
extra_inputs
=
{
"mode"
:
"und"
}
output
=
self
.
language_model
.
forward_inference
(
packed_query_sequence
=
packed_sequence
,
query_lens
=
packed_seqlens
,
packed_query_position_ids
=
packed_position_ids
,
packed_query_indexes
=
packed_indexes
,
past_key_values
=
past_key_values
,
packed_key_value_indexes
=
packed_key_value_indexes
,
key_values_lens
=
key_values_lens
,
update_past_key_values
=
True
,
is_causal
=
False
,
**
extra_inputs
,
)
past_key_values
=
output
.
past_key_values
return
past_key_values
def
prepare_vae_images
(
self
,
curr_kvlens
,
curr_rope
,
images
,
transforms
,
new_token_ids
,
timestep
=
0
):
patchified_vae_latent_shapes
,
packed_vae_position_ids
=
list
(),
list
()
packed_vae_token_indexes
=
list
()
packed_text_ids
,
packed_text_indexes
=
list
(),
list
()
packed_seqlens
,
packed_position_ids
,
packed_indexes
=
list
(),
list
(),
list
()
packed_key_value_indexes
=
list
()
_curr
=
curr
=
0
vae_image_tensors
=
list
()
newlens
,
new_rope
=
list
(),
list
()
for
image
,
curr_kvlen
,
curr_position_id
in
zip
(
images
,
curr_kvlens
,
curr_rope
):
packed_key_value_indexes
.
extend
(
range
(
curr
,
curr
+
curr_kvlen
))
curr
+=
curr_kvlen
packed_text_ids
.
append
(
new_token_ids
[
"start_of_image"
])
packed_text_indexes
.
append
(
_curr
)
packed_indexes
.
append
(
curr
)
curr
+=
1
_curr
+=
1
image_tensor
=
transforms
(
image
)
vae_image_tensors
.
append
(
image_tensor
)
vae_posiiton_ids
=
self
.
get_flattened_position_ids
(
image_tensor
.
size
(
1
),
image_tensor
.
size
(
2
),
self
.
latent_downsample
,
max_num_patches_per_side
=
self
.
max_latent_size
,
)
packed_vae_position_ids
.
append
(
vae_posiiton_ids
)
H
,
W
=
image_tensor
.
shape
[
1
:]
h
=
H
//
self
.
latent_downsample
w
=
W
//
self
.
latent_downsample
patchified_vae_latent_shapes
.
append
((
h
,
w
))
num_img_tokens
=
w
*
h
packed_vae_token_indexes
.
extend
(
range
(
_curr
,
_curr
+
num_img_tokens
))
packed_indexes
.
extend
(
range
(
curr
,
curr
+
num_img_tokens
))
curr
+=
num_img_tokens
_curr
+=
num_img_tokens
packed_text_ids
.
append
(
new_token_ids
[
"end_of_image"
])
packed_text_indexes
.
append
(
_curr
)
packed_indexes
.
append
(
curr
)
curr
+=
1
_curr
+=
1
packed_position_ids
.
extend
([
curr_position_id
]
*
(
num_img_tokens
+
2
))
packed_seqlens
.
append
(
num_img_tokens
+
2
)
newlens
.
append
(
curr_kvlen
+
num_img_tokens
+
2
)
new_rope
.
append
(
curr_position_id
+
1
)
image_sizes
=
[
item
.
shape
for
item
in
vae_image_tensors
]
max_image_size
=
[
max
(
item
)
for
item
in
list
(
zip
(
*
image_sizes
))]
padded_images
=
torch
.
zeros
(
size
=
(
len
(
vae_image_tensors
),
*
max_image_size
))
for
i
,
image_tensor
in
enumerate
(
vae_image_tensors
):
padded_images
[
i
,
:,
:
image_tensor
.
shape
[
1
],
:
image_tensor
.
shape
[
2
]]
=
(
image_tensor
)
generation_input
=
{
"padded_images"
:
padded_images
,
"patchified_vae_latent_shapes"
:
patchified_vae_latent_shapes
,
"packed_vae_position_ids"
:
torch
.
cat
(
packed_vae_position_ids
,
dim
=
0
),
"packed_timesteps"
:
torch
.
tensor
([
timestep
]),
"packed_vae_token_indexes"
:
torch
.
tensor
(
packed_vae_token_indexes
,
dtype
=
torch
.
long
),
"packed_text_ids"
:
torch
.
tensor
(
packed_text_ids
,
dtype
=
torch
.
long
),
"packed_text_indexes"
:
torch
.
tensor
(
packed_text_indexes
,
dtype
=
torch
.
long
),
"packed_position_ids"
:
torch
.
tensor
(
packed_position_ids
,
dtype
=
torch
.
long
),
"packed_seqlens"
:
torch
.
tensor
(
packed_seqlens
,
dtype
=
torch
.
int
),
"packed_indexes"
:
torch
.
tensor
(
packed_indexes
,
dtype
=
torch
.
long
),
"packed_key_value_indexes"
:
torch
.
tensor
(
packed_key_value_indexes
,
dtype
=
torch
.
long
),
"key_values_lens"
:
torch
.
tensor
(
curr_kvlens
,
dtype
=
torch
.
int
),
}
return
generation_input
,
newlens
,
new_rope
@
torch
.
no_grad
def
forward_cache_update_vae
(
self
,
vae_model
,
past_key_values
:
NaiveCache
,
padded_images
:
torch
.
Tensor
,
patchified_vae_latent_shapes
:
List
,
packed_vae_position_ids
:
torch
.
LongTensor
,
packed_timesteps
:
torch
.
Tensor
,
packed_vae_token_indexes
:
torch
.
LongTensor
,
packed_text_ids
:
torch
.
LongTensor
,
packed_text_indexes
:
torch
.
LongTensor
,
packed_position_ids
:
torch
.
LongTensor
,
packed_seqlens
:
torch
.
IntTensor
,
packed_indexes
:
torch
.
LongTensor
,
key_values_lens
:
torch
.
IntTensor
,
packed_key_value_indexes
:
torch
.
Tensor
,
):
packed_text_embedding
=
self
.
language_model
.
model
.
embed_tokens
(
packed_text_ids
)
packed_sequence
=
packed_text_embedding
.
new_zeros
(
(
sum
(
packed_seqlens
),
self
.
hidden_size
)
)
packed_sequence
[
packed_text_indexes
]
=
packed_text_embedding
padded_latent
=
vae_model
.
encode
(
padded_images
)
p
=
self
.
latent_patch_size
packed_latent
=
list
()
for
latent
,
(
h
,
w
)
in
zip
(
padded_latent
,
patchified_vae_latent_shapes
):
latent
=
latent
[:,
:
h
*
p
,
:
w
*
p
].
reshape
(
self
.
latent_channel
,
h
,
p
,
w
,
p
)
latent
=
torch
.
einsum
(
"chpwq->hwpqc"
,
latent
).
reshape
(
-
1
,
p
*
p
*
self
.
latent_channel
)
packed_latent
.
append
(
latent
)
packed_latent
=
torch
.
cat
(
packed_latent
,
dim
=
0
)
packed_pos_embed
=
self
.
latent_pos_embed
(
packed_vae_position_ids
)
packed_timestep_embeds
=
self
.
time_embedder
(
packed_timesteps
)
packed_latent
=
(
self
.
vae2llm
(
packed_latent
)
+
packed_timestep_embeds
+
packed_pos_embed
)
if
packed_latent
.
dtype
!=
packed_sequence
.
dtype
:
packed_latent
=
packed_latent
.
to
(
packed_sequence
.
dtype
)
packed_sequence
[
packed_vae_token_indexes
]
=
packed_latent
extra_inputs
=
{}
if
self
.
use_moe
:
extra_inputs
=
{
"mode"
:
"gen"
,
"packed_vae_token_indexes"
:
packed_vae_token_indexes
,
"packed_text_indexes"
:
packed_text_indexes
,
}
output
=
self
.
language_model
.
forward_inference
(
packed_query_sequence
=
packed_sequence
,
query_lens
=
packed_seqlens
,
packed_query_position_ids
=
packed_position_ids
,
packed_query_indexes
=
packed_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
True
,
is_causal
=
False
,
**
extra_inputs
,
)
past_key_values
=
output
.
past_key_values
return
past_key_values
def
prepare_vae_latent
(
self
,
curr_kvlens
,
curr_rope
,
image_sizes
,
new_token_ids
):
packed_text_ids
,
packed_text_indexes
=
list
(),
list
()
packed_vae_position_ids
,
packed_vae_token_indexes
,
packed_init_noises
=
(
list
(),
list
(),
list
(),
)
packed_position_ids
,
packed_seqlens
,
packed_indexes
=
list
(),
list
(),
list
()
packed_key_value_indexes
=
list
()
query_curr
=
curr
=
0
for
(
H
,
W
),
curr_kvlen
,
curr_position_id
in
zip
(
image_sizes
,
curr_kvlens
,
curr_rope
):
packed_key_value_indexes
.
extend
(
range
(
curr
,
curr
+
curr_kvlen
))
curr
+=
curr_kvlen
packed_text_ids
.
append
(
new_token_ids
[
"start_of_image"
])
packed_text_indexes
.
append
(
query_curr
)
packed_indexes
.
append
(
curr
)
curr
+=
1
query_curr
+=
1
vae_posiiton_ids
=
self
.
get_flattened_position_ids
(
H
,
W
,
self
.
latent_downsample
,
max_num_patches_per_side
=
self
.
max_latent_size
,
)
packed_vae_position_ids
.
append
(
vae_posiiton_ids
)
h
,
w
=
H
//
self
.
latent_downsample
,
W
//
self
.
latent_downsample
num_image_tokens
=
h
*
w
packed_init_noises
.
append
(
torch
.
randn
(
num_image_tokens
,
self
.
latent_channel
*
self
.
latent_patch_size
**
2
)
)
packed_vae_token_indexes
.
extend
(
range
(
query_curr
,
query_curr
+
num_image_tokens
)
)
packed_indexes
.
extend
(
range
(
curr
,
curr
+
num_image_tokens
))
curr
+=
num_image_tokens
query_curr
+=
num_image_tokens
packed_text_ids
.
append
(
new_token_ids
[
"end_of_image"
])
packed_text_indexes
.
append
(
query_curr
)
packed_indexes
.
append
(
curr
)
curr
+=
1
query_curr
+=
1
packed_position_ids
.
extend
([
curr_position_id
]
*
(
num_image_tokens
+
2
))
packed_seqlens
.
append
(
num_image_tokens
+
2
)
generation_input
=
{
"packed_text_ids"
:
torch
.
tensor
(
packed_text_ids
,
dtype
=
torch
.
long
),
"packed_text_indexes"
:
torch
.
tensor
(
packed_text_indexes
,
dtype
=
torch
.
long
),
"packed_init_noises"
:
torch
.
cat
(
packed_init_noises
,
dim
=
0
),
"packed_vae_position_ids"
:
torch
.
cat
(
packed_vae_position_ids
,
dim
=
0
),
"packed_vae_token_indexes"
:
torch
.
tensor
(
packed_vae_token_indexes
,
dtype
=
torch
.
long
),
"packed_seqlens"
:
torch
.
tensor
(
packed_seqlens
,
dtype
=
torch
.
int
),
"packed_position_ids"
:
torch
.
tensor
(
packed_position_ids
,
dtype
=
torch
.
long
),
"key_values_lens"
:
torch
.
tensor
(
curr_kvlens
,
dtype
=
torch
.
int
),
"packed_indexes"
:
torch
.
tensor
(
packed_indexes
,
dtype
=
torch
.
long
),
"packed_key_value_indexes"
:
torch
.
tensor
(
packed_key_value_indexes
,
dtype
=
torch
.
long
),
}
return
generation_input
def
prepare_vae_latent_cfg
(
self
,
curr_kvlens
,
curr_rope
,
image_sizes
):
packed_position_ids
,
packed_indexes
,
packed_key_value_indexes
=
(
list
(),
list
(),
list
(),
)
query_curr
=
curr
=
0
for
(
H
,
W
),
curr_kvlen
,
curr_position_id
in
zip
(
image_sizes
,
curr_kvlens
,
curr_rope
):
packed_key_value_indexes
.
extend
(
range
(
curr
,
curr
+
curr_kvlen
))
curr
+=
curr_kvlen
packed_indexes
.
append
(
curr
)
curr
+=
1
query_curr
+=
1
h
,
w
=
H
//
self
.
latent_downsample
,
W
//
self
.
latent_downsample
num_image_tokens
=
h
*
w
packed_indexes
.
extend
(
range
(
curr
,
curr
+
num_image_tokens
))
curr
+=
num_image_tokens
query_curr
+=
num_image_tokens
packed_indexes
.
append
(
curr
)
curr
+=
1
query_curr
+=
1
packed_position_ids
.
extend
([
curr_position_id
]
*
(
num_image_tokens
+
2
))
generation_input
=
{
"cfg_packed_position_ids"
:
torch
.
tensor
(
packed_position_ids
,
dtype
=
torch
.
long
),
"cfg_key_values_lens"
:
torch
.
tensor
(
curr_kvlens
,
dtype
=
torch
.
int
),
"cfg_packed_query_indexes"
:
torch
.
tensor
(
packed_indexes
,
dtype
=
torch
.
long
),
"cfg_packed_key_value_indexes"
:
torch
.
tensor
(
packed_key_value_indexes
,
dtype
=
torch
.
long
),
}
return
generation_input
@
torch
.
no_grad
def
generate_image
(
self
,
packed_text_ids
:
torch
.
LongTensor
,
packed_text_indexes
:
torch
.
LongTensor
,
packed_init_noises
:
torch
.
Tensor
,
packed_vae_position_ids
:
torch
.
LongTensor
,
packed_vae_token_indexes
:
torch
.
LongTensor
,
packed_seqlens
:
torch
.
IntTensor
,
packed_position_ids
:
torch
.
LongTensor
,
packed_indexes
:
torch
.
LongTensor
,
past_key_values
:
NaiveCache
,
key_values_lens
:
torch
.
IntTensor
,
packed_key_value_indexes
:
torch
.
LongTensor
,
num_timesteps
:
int
=
24
,
timestep_shift
:
float
=
1.0
,
cfg_renorm_min
:
float
=
0.0
,
cfg_renorm_type
:
str
=
"global"
,
cfg_interval
:
Optional
[
Tuple
[
float
,
float
]]
=
[
0
,
1
],
# cfg_text
cfg_text_scale
:
float
=
1.0
,
cfg_text_packed_query_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_text_packed_position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_text_past_key_values
:
Optional
[
NaiveCache
]
=
None
,
cfg_text_key_values_lens
:
Optional
[
torch
.
IntTensor
]
=
None
,
cfg_text_packed_key_value_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
# cfg_img
cfg_img_scale
:
float
=
1.0
,
cfg_img_packed_query_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_img_packed_position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_img_past_key_values
:
Optional
[
NaiveCache
]
=
None
,
cfg_img_key_values_lens
:
Optional
[
torch
.
IntTensor
]
=
None
,
cfg_img_packed_key_value_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_type
:
str
=
"parallel"
,
):
x_t
=
packed_init_noises
timesteps
=
torch
.
linspace
(
1
,
0
,
num_timesteps
,
device
=
x_t
.
device
)
timesteps
=
timestep_shift
*
timesteps
/
(
1
+
(
timestep_shift
-
1
)
*
timesteps
)
dts
=
timesteps
[:
-
1
]
-
timesteps
[
1
:]
timesteps
=
timesteps
[:
-
1
]
for
i
,
t
in
tqdm
(
enumerate
(
timesteps
),
total
=
len
(
timesteps
)):
timestep
=
torch
.
tensor
([
t
]
*
x_t
.
shape
[
0
],
device
=
x_t
.
device
)
if
t
>
cfg_interval
[
0
]
and
t
<=
cfg_interval
[
1
]:
cfg_text_scale_
=
cfg_text_scale
cfg_img_scale_
=
cfg_img_scale
else
:
cfg_text_scale_
=
1.0
cfg_img_scale_
=
1.0
v_t
=
self
.
_forward_flow
(
x_t
=
x_t
,
timestep
=
timestep
,
packed_vae_token_indexes
=
packed_vae_token_indexes
,
packed_vae_position_ids
=
packed_vae_position_ids
,
packed_text_ids
=
packed_text_ids
,
packed_text_indexes
=
packed_text_indexes
,
packed_position_ids
=
packed_position_ids
,
packed_indexes
=
packed_indexes
,
packed_seqlens
=
packed_seqlens
,
key_values_lens
=
key_values_lens
,
past_key_values
=
past_key_values
,
packed_key_value_indexes
=
packed_key_value_indexes
,
cfg_renorm_min
=
cfg_renorm_min
,
cfg_renorm_type
=
cfg_renorm_type
,
# cfg_text
cfg_text_scale
=
cfg_text_scale_
,
cfg_text_packed_position_ids
=
cfg_text_packed_position_ids
,
cfg_text_packed_query_indexes
=
cfg_text_packed_query_indexes
,
cfg_text_key_values_lens
=
cfg_text_key_values_lens
,
cfg_text_past_key_values
=
cfg_text_past_key_values
,
cfg_text_packed_key_value_indexes
=
cfg_text_packed_key_value_indexes
,
# cfg_img
cfg_img_scale
=
cfg_img_scale_
,
cfg_img_packed_position_ids
=
cfg_img_packed_position_ids
,
cfg_img_packed_query_indexes
=
cfg_img_packed_query_indexes
,
cfg_img_key_values_lens
=
cfg_img_key_values_lens
,
cfg_img_past_key_values
=
cfg_img_past_key_values
,
cfg_img_packed_key_value_indexes
=
cfg_img_packed_key_value_indexes
,
cfg_type
=
cfg_type
,
)
x_t
=
(
x_t
-
v_t
.
to
(
x_t
.
device
)
*
dts
[
i
]
)
# velocity pointing from data to noise
unpacked_latent
=
x_t
.
split
((
packed_seqlens
-
2
).
tolist
())
return
unpacked_latent
@
torch
.
no_grad
def
_forward_flow
(
self
,
x_t
:
torch
.
Tensor
,
timestep
:
torch
.
LongTensor
,
packed_vae_token_indexes
:
torch
.
LongTensor
,
packed_vae_position_ids
:
torch
.
LongTensor
,
packed_text_ids
:
torch
.
LongTensor
,
packed_text_indexes
:
torch
.
LongTensor
,
packed_indexes
:
torch
.
LongTensor
,
packed_position_ids
:
torch
.
LongTensor
,
packed_seqlens
:
torch
.
IntTensor
,
key_values_lens
:
torch
.
IntTensor
,
past_key_values
:
NaiveCache
,
packed_key_value_indexes
:
torch
.
LongTensor
,
cfg_renorm_min
:
float
=
0.0
,
cfg_renorm_type
:
str
=
"global"
,
# cfg_text
cfg_text_scale
:
float
=
1.0
,
cfg_text_packed_position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_text_packed_query_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_text_key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
cfg_text_past_key_values
:
Optional
[
NaiveCache
]
=
None
,
cfg_text_packed_key_value_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
# cfg_img
cfg_img_scale
:
float
=
1.0
,
cfg_img_packed_position_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_img_packed_query_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_img_key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
cfg_img_past_key_values
:
Optional
[
NaiveCache
]
=
None
,
cfg_img_packed_key_value_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
cfg_type
:
str
=
"parallel"
,
):
packed_text_embedding
=
self
.
language_model
.
model
.
embed_tokens
(
packed_text_ids
)
packed_sequence
=
packed_text_embedding
.
new_zeros
(
(
sum
(
packed_seqlens
),
self
.
hidden_size
)
)
packed_sequence
[
packed_text_indexes
]
=
packed_text_embedding
assert
timestep
.
unique
().
shape
[
0
]
==
1
packed_pos_embed
=
self
.
latent_pos_embed
(
packed_vae_position_ids
)
packed_timestep_embeds
=
self
.
time_embedder
(
timestep
)
x_t
=
self
.
vae2llm
(
x_t
)
+
packed_timestep_embeds
+
packed_pos_embed
if
x_t
.
dtype
!=
packed_sequence
.
dtype
:
x_t
=
x_t
.
to
(
packed_sequence
.
dtype
)
packed_sequence
[
packed_vae_token_indexes
]
=
x_t
extra_inputs
=
{}
if
self
.
use_moe
:
extra_inputs
=
{
"mode"
:
"gen"
,
"packed_vae_token_indexes"
:
packed_vae_token_indexes
,
"packed_text_indexes"
:
packed_text_indexes
,
}
output
=
self
.
language_model
.
forward_inference
(
packed_query_sequence
=
packed_sequence
,
query_lens
=
packed_seqlens
,
packed_query_position_ids
=
packed_position_ids
,
packed_query_indexes
=
packed_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
False
,
is_causal
=
False
,
**
extra_inputs
,
)
v_t
=
self
.
llm2vae
(
output
.
packed_query_sequence
)
v_t
=
v_t
[
packed_vae_token_indexes
]
if
cfg_text_scale
>
1.0
:
cfg_text_output
=
self
.
language_model
.
forward_inference
(
packed_query_sequence
=
packed_sequence
,
query_lens
=
packed_seqlens
,
packed_query_position_ids
=
cfg_text_packed_position_ids
,
packed_query_indexes
=
cfg_text_packed_query_indexes
,
past_key_values
=
cfg_text_past_key_values
,
key_values_lens
=
cfg_text_key_values_lens
,
packed_key_value_indexes
=
cfg_text_packed_key_value_indexes
,
update_past_key_values
=
False
,
is_causal
=
False
,
**
extra_inputs
,
)
cfg_text_v_t
=
self
.
llm2vae
(
cfg_text_output
.
packed_query_sequence
)
cfg_text_v_t
=
cfg_text_v_t
[
packed_vae_token_indexes
]
if
cfg_img_scale
>
1.0
:
cfg_img_output
=
self
.
language_model
.
forward_inference
(
packed_query_sequence
=
packed_sequence
,
query_lens
=
packed_seqlens
,
packed_query_position_ids
=
cfg_img_packed_position_ids
,
packed_query_indexes
=
cfg_img_packed_query_indexes
,
past_key_values
=
cfg_img_past_key_values
,
key_values_lens
=
cfg_img_key_values_lens
,
packed_key_value_indexes
=
cfg_img_packed_key_value_indexes
,
update_past_key_values
=
False
,
is_causal
=
False
,
**
extra_inputs
,
)
cfg_img_v_t
=
self
.
llm2vae
(
cfg_img_output
.
packed_query_sequence
)
cfg_img_v_t
=
cfg_img_v_t
[
packed_vae_token_indexes
]
if
cfg_text_scale
>
1.0
:
if
cfg_renorm_type
==
"text_channel"
:
v_t_text_
=
cfg_text_v_t
+
cfg_text_scale
*
(
v_t
-
cfg_text_v_t
)
norm_v_t
=
torch
.
norm
(
v_t
,
dim
=-
1
,
keepdim
=
True
)
norm_v_t_text_
=
torch
.
norm
(
v_t_text_
,
dim
=-
1
,
keepdim
=
True
)
scale
=
(
norm_v_t
/
(
norm_v_t_text_
+
1e-8
)).
clamp
(
min
=
cfg_renorm_min
,
max
=
1.0
)
v_t_text
=
v_t_text_
*
scale
if
cfg_img_scale
>
1.0
:
v_t
=
cfg_img_v_t
+
cfg_img_scale
*
(
v_t_text
-
cfg_img_v_t
)
else
:
v_t
=
v_t_text
else
:
v_t_text_
=
cfg_text_v_t
+
cfg_text_scale
*
(
v_t
-
cfg_text_v_t
)
if
cfg_img_scale
>
1.0
:
v_t_
=
cfg_img_v_t
+
cfg_img_scale
*
(
v_t_text_
-
cfg_img_v_t
)
else
:
v_t_
=
v_t_text_
# NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
if
cfg_renorm_type
==
"global"
:
norm_v_t
=
torch
.
norm
(
v_t
)
norm_v_t_
=
torch
.
norm
(
v_t_
)
elif
cfg_renorm_type
==
"channel"
:
norm_v_t
=
torch
.
norm
(
v_t
,
dim
=-
1
,
keepdim
=
True
)
norm_v_t_
=
torch
.
norm
(
v_t_
,
dim
=-
1
,
keepdim
=
True
)
else
:
raise
NotImplementedError
(
f
"
{
cfg_renorm_type
}
is not suppoprted"
)
scale
=
(
norm_v_t
/
(
norm_v_t_
+
1e-8
)).
clamp
(
min
=
cfg_renorm_min
,
max
=
1.0
)
v_t
=
v_t_
*
scale
else
:
# No CFG
pass
return
v_t
def
prepare_start_tokens
(
self
,
curr_kvlens
,
curr_rope
,
new_token_ids
):
packed_start_tokens
,
packed_key_value_indexes
=
list
(),
list
()
packed_query_position_ids
=
list
()
curr
=
0
for
curr_kvlen
,
curr_position_id
in
zip
(
curr_kvlens
,
curr_rope
):
packed_key_value_indexes
.
extend
(
range
(
curr
,
curr
+
curr_kvlen
))
packed_start_tokens
.
append
(
new_token_ids
[
"bos_token_id"
])
packed_query_position_ids
.
append
(
curr_position_id
)
curr
+=
curr_kvlen
generation_input
=
{
"packed_start_tokens"
:
torch
.
tensor
(
packed_start_tokens
,
dtype
=
torch
.
long
),
"packed_query_position_ids"
:
torch
.
tensor
(
packed_query_position_ids
,
dtype
=
torch
.
long
),
"key_values_lens"
:
torch
.
tensor
(
curr_kvlens
,
dtype
=
torch
.
int
),
"packed_key_value_indexes"
:
torch
.
tensor
(
packed_key_value_indexes
,
dtype
=
torch
.
long
),
}
return
generation_input
@
torch
.
no_grad
def
generate_text
(
self
,
past_key_values
:
NaiveCache
,
packed_key_value_indexes
:
torch
.
LongTensor
,
key_values_lens
:
torch
.
IntTensor
,
packed_start_tokens
:
torch
.
LongTensor
,
packed_query_position_ids
:
torch
.
LongTensor
,
max_length
:
int
,
do_sample
:
bool
=
False
,
temperature
:
float
=
1.0
,
end_token_id
:
int
=
None
,
):
step
=
0
generated_sequence
=
[]
curr_tokens
=
packed_start_tokens
while
step
<
max_length
:
generated_sequence
.
append
(
curr_tokens
)
packed_text_embedding
=
self
.
language_model
.
model
.
embed_tokens
(
curr_tokens
)
query_lens
=
torch
.
ones_like
(
curr_tokens
)
packed_query_indexes
=
torch
.
cumsum
(
key_values_lens
,
dim
=
0
)
+
torch
.
arange
(
0
,
len
(
key_values_lens
),
device
=
key_values_lens
.
device
,
dtype
=
key_values_lens
.
dtype
,
)
uppacked
=
list
(
packed_key_value_indexes
.
split
(
key_values_lens
.
tolist
(),
dim
=
0
)
)
for
i
in
range
(
len
(
uppacked
)):
uppacked
[
i
]
+=
i
packed_key_value_indexes
=
torch
.
cat
(
uppacked
,
dim
=
0
)
extra_inputs
=
{}
if
self
.
use_moe
:
extra_inputs
=
{
"mode"
:
"und"
}
output
=
self
.
language_model
.
forward_inference
(
packed_query_sequence
=
packed_text_embedding
,
query_lens
=
query_lens
,
packed_query_position_ids
=
packed_query_position_ids
,
packed_query_indexes
=
packed_query_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
True
,
is_causal
=
True
,
**
extra_inputs
,
)
past_key_values
=
output
.
past_key_values
packed_query_sequence
=
output
.
packed_query_sequence
pred_logits
=
self
.
language_model
.
lm_head
(
packed_query_sequence
)
if
do_sample
:
probs
=
nn
.
functional
.
softmax
(
pred_logits
/
temperature
,
dim
=-
1
)
curr_tokens
=
torch
.
multinomial
(
probs
,
num_samples
=
1
).
squeeze
(
1
)
else
:
curr_tokens
=
torch
.
argmax
(
pred_logits
,
dim
=-
1
)
uppacked
=
list
(
packed_key_value_indexes
.
split
(
key_values_lens
.
tolist
(),
dim
=
0
)
)
for
i
in
range
(
len
(
uppacked
)):
uppacked
[
i
]
=
torch
.
cat
(
[
uppacked
[
i
],
torch
.
tensor
([
uppacked
[
i
][
-
1
]
+
1
],
device
=
uppacked
[
i
].
device
),
],
dim
=
0
,
)
packed_key_value_indexes
=
torch
.
cat
(
uppacked
,
dim
=
0
)
key_values_lens
=
key_values_lens
+
1
packed_query_position_ids
=
packed_query_position_ids
+
1
step
+=
1
if
(
end_token_id
is
not
None
and
curr_tokens
[
0
]
==
end_token_id
):
# only support batch=1
break
output_device
=
generated_sequence
[
0
].
device
return
torch
.
stack
([
i
.
to
(
output_device
)
for
i
in
generated_sequence
],
dim
=
0
)
# for evaluation
@
torch
.
no_grad
()
def
chat
(
self
,
tokenizer
,
new_token_ids
,
image_transform
,
images
,
prompt
,
max_length
:
int
,
do_sample
:
bool
=
False
,
temperature
:
float
=
1.0
,
):
device
=
next
(
self
.
parameters
()).
device
if
isinstance
(
new_token_ids
,
dict
):
for
k
,
v
in
new_token_ids
.
items
():
if
torch
.
is_tensor
(
v
):
new_token_ids
[
k
]
=
v
.
to
(
device
)
elif
torch
.
is_tensor
(
new_token_ids
):
new_token_ids
=
new_token_ids
.
to
(
device
)
# prefill
past_key_values
=
NaiveCache
(
self
.
config
.
llm_config
.
num_hidden_layers
)
newlens
=
[
0
]
new_rope
=
[
0
]
# add images
for
image
in
images
:
generation_input
,
newlens
,
new_rope
=
self
.
prepare_vit_images
(
curr_kvlens
=
newlens
,
curr_rope
=
new_rope
,
images
=
[
image
],
transforms
=
image_transform
,
new_token_ids
=
new_token_ids
,
)
for
k
,
v
in
generation_input
.
items
():
if
torch
.
is_tensor
(
v
):
generation_input
[
k
]
=
v
.
to
(
device
)
with
torch
.
amp
.
autocast
(
"cuda"
,
enabled
=
True
,
dtype
=
torch
.
bfloat16
):
past_key_values
=
self
.
forward_cache_update_vit
(
past_key_values
,
**
generation_input
)
# add text
generation_input
,
newlens
,
new_rope
=
self
.
prepare_prompts
(
curr_kvlens
=
newlens
,
curr_rope
=
new_rope
,
prompts
=
[
prompt
],
tokenizer
=
tokenizer
,
new_token_ids
=
new_token_ids
,
)
for
k
,
v
in
generation_input
.
items
():
if
torch
.
is_tensor
(
v
):
generation_input
[
k
]
=
v
.
to
(
device
)
with
torch
.
amp
.
autocast
(
"cuda"
,
enabled
=
True
,
dtype
=
torch
.
bfloat16
):
past_key_values
=
self
.
forward_cache_update_text
(
past_key_values
,
**
generation_input
)
# decode
generation_input
=
self
.
prepare_start_tokens
(
newlens
,
new_rope
,
new_token_ids
)
for
k
,
v
in
generation_input
.
items
():
if
torch
.
is_tensor
(
v
):
generation_input
[
k
]
=
v
.
to
(
device
)
with
torch
.
amp
.
autocast
(
"cuda"
,
enabled
=
True
,
dtype
=
torch
.
bfloat16
):
unpacked_latent
=
self
.
generate_text
(
past_key_values
=
past_key_values
,
max_length
=
max_length
,
do_sample
=
do_sample
,
temperature
=
temperature
,
end_token_id
=
new_token_ids
[
"eos_token_id"
],
**
generation_input
,
)
output
=
tokenizer
.
decode
(
unpacked_latent
[:,
0
])
output
=
output
.
split
(
"<|im_end|>"
)[
0
].
split
(
"<|im_start|>"
)[
1
]
return
output
SenseNova-SI-main/training/bagel/modeling/bagel/modeling_utils.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2022 Facebook, Inc. and its affiliates.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: CC BY-NC 4.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under CC BY-NC 4.0, with the full license text
# available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
#
# This modified file is released under the same license.
import
math
import
numpy
as
np
import
torch
from
torch
import
nn
from
transformers.activations
import
ACT2FN
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
def
get_2d_sincos_pos_embed
(
embed_dim
,
grid_size
,
cls_token
=
False
,
extra_tokens
=
0
):
grid_h
=
np
.
arange
(
grid_size
,
dtype
=
np
.
float32
)
grid_w
=
np
.
arange
(
grid_size
,
dtype
=
np
.
float32
)
grid
=
np
.
meshgrid
(
grid_w
,
grid_h
)
# here w goes first
grid
=
np
.
stack
(
grid
,
axis
=
0
)
grid
=
grid
.
reshape
([
2
,
1
,
grid_size
,
grid_size
])
pos_embed
=
get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
)
if
cls_token
and
extra_tokens
>
0
:
pos_embed
=
np
.
concatenate
(
[
np
.
zeros
([
extra_tokens
,
embed_dim
]),
pos_embed
],
axis
=
0
)
return
pos_embed
def
get_2d_sincos_pos_embed_from_grid
(
embed_dim
,
grid
):
assert
embed_dim
%
2
==
0
# use half of dimensions to encode grid_h
emb_h
=
get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
0
])
# (H*W, D/2)
emb_w
=
get_1d_sincos_pos_embed_from_grid
(
embed_dim
//
2
,
grid
[
1
])
# (H*W, D/2)
emb
=
np
.
concatenate
([
emb_h
,
emb_w
],
axis
=
1
)
# (H*W, D)
return
emb
def
get_1d_sincos_pos_embed_from_grid
(
embed_dim
,
pos
):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert
embed_dim
%
2
==
0
omega
=
np
.
arange
(
embed_dim
//
2
,
dtype
=
np
.
float64
)
omega
/=
embed_dim
/
2.0
omega
=
1.0
/
10000
**
omega
# (D/2,)
pos
=
pos
.
reshape
(
-
1
)
# (M,)
out
=
np
.
einsum
(
"m,d->md"
,
pos
,
omega
)
# (M, D/2), outer product
emb_sin
=
np
.
sin
(
out
)
# (M, D/2)
emb_cos
=
np
.
cos
(
out
)
# (M, D/2)
emb
=
np
.
concatenate
([
emb_sin
,
emb_cos
],
axis
=
1
)
# (M, D)
return
emb
# --------------------------------------------------------
# TimestepEmbedder
# Reference:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
class
TimestepEmbedder
(
nn
.
Module
):
"""
Embeds scalar timesteps into vector representations.
"""
def
__init__
(
self
,
hidden_size
,
frequency_embedding_size
=
256
):
super
().
__init__
()
self
.
mlp
=
nn
.
Sequential
(
nn
.
Linear
(
frequency_embedding_size
,
hidden_size
,
bias
=
True
),
nn
.
SiLU
(),
nn
.
Linear
(
hidden_size
,
hidden_size
,
bias
=
True
),
)
self
.
frequency_embedding_size
=
frequency_embedding_size
@
staticmethod
def
timestep_embedding
(
t
,
dim
,
max_period
=
10000
):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half
=
dim
//
2
freqs
=
torch
.
exp
(
-
math
.
log
(
max_period
)
*
torch
.
arange
(
start
=
0
,
end
=
half
,
dtype
=
torch
.
float32
)
/
half
).
to
(
device
=
t
.
device
)
args
=
t
[:,
None
].
float
()
*
freqs
[
None
]
embedding
=
torch
.
cat
([
torch
.
cos
(
args
),
torch
.
sin
(
args
)],
dim
=-
1
)
if
dim
%
2
:
embedding
=
torch
.
cat
(
[
embedding
,
torch
.
zeros_like
(
embedding
[:,
:
1
])],
dim
=-
1
)
return
embedding
def
forward
(
self
,
t
):
t_freq
=
self
.
timestep_embedding
(
t
,
self
.
frequency_embedding_size
)
t_emb
=
self
.
mlp
(
t_freq
)
return
t_emb
class
MLPconnector
(
nn
.
Module
):
def
__init__
(
self
,
in_dim
:
int
,
out_dim
:
int
,
hidden_act
:
str
):
super
().
__init__
()
self
.
activation_fn
=
ACT2FN
[
hidden_act
]
self
.
fc1
=
nn
.
Linear
(
in_dim
,
out_dim
)
self
.
fc2
=
nn
.
Linear
(
out_dim
,
out_dim
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
PositionEmbedding
(
nn
.
Module
):
def
__init__
(
self
,
max_num_patch_per_side
,
hidden_size
):
super
().
__init__
()
self
.
max_num_patch_per_side
=
max_num_patch_per_side
self
.
hidden_size
=
hidden_size
self
.
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
max_num_patch_per_side
**
2
,
hidden_size
),
requires_grad
=
False
)
self
.
_init_weights
()
def
_init_weights
(
self
):
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed
=
get_2d_sincos_pos_embed
(
self
.
hidden_size
,
self
.
max_num_patch_per_side
)
self
.
pos_embed
.
data
.
copy_
(
torch
.
from_numpy
(
pos_embed
).
float
())
def
forward
(
self
,
position_ids
):
return
self
.
pos_embed
[
position_ids
]
SenseNova-SI-main/training/bagel/modeling/bagel/qwen2_navit.py
0 → 100644
View file @
876a36a4
# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
from
dataclasses
import
dataclass
from
functools
import
partial
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
flash_attn
import
flash_attn_varlen_func
from
torch
import
nn
from
torch.nn.attention
import
SDPBackend
,
sdpa_kernel
from
torch.nn.attention.flex_attention
import
flex_attention
from
torch.nn.functional
import
scaled_dot_product_attention
from
transformers.utils
import
ModelOutput
from
modeling.qwen2.configuration_qwen2
import
Qwen2Config
as
_Qwen2Config
from
modeling.qwen2.modeling_qwen2
import
(
Qwen2Attention
,
Qwen2MLP
,
Qwen2PreTrainedModel
,
Qwen2RMSNorm
,
Qwen2RotaryEmbedding
,
apply_rotary_pos_emb
,
)
torch
.
_dynamo
.
config
.
cache_size_limit
=
512
torch
.
_dynamo
.
config
.
accumulated_cache_size_limit
=
4096
# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
flex_attention
=
torch
.
compile
(
flex_attention
)
class
Qwen2Config
(
_Qwen2Config
):
r
"""
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen2Model, Qwen2Config
>>> # Initializing a Qwen2 style configuration
>>> configuration = Qwen2Config()
>>> # Initializing a model from the Qwen2-7B style configuration
>>> model = Qwen2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type
=
"qwen2"
keys_to_ignore_at_inference
=
[
"past_key_values"
]
def
__init__
(
self
,
vocab_size
=
151936
,
hidden_size
=
4096
,
intermediate_size
=
22016
,
num_hidden_layers
=
32
,
num_attention_heads
=
32
,
num_key_value_heads
=
32
,
hidden_act
=
"silu"
,
max_position_embeddings
=
32768
,
initializer_range
=
0.02
,
rms_norm_eps
=
1e-6
,
use_cache
=
True
,
tie_word_embeddings
=
False
,
rope_theta
=
10000.0
,
rope_scaling
=
None
,
use_sliding_window
=
False
,
sliding_window
=
4096
,
max_window_layers
=
28
,
attention_dropout
=
0.0
,
is_causal
=
True
,
_attn_implementation
=
"flash_attention_2"
,
qk_norm
=
True
,
layer_module
=
"Qwen2DecoderLayer"
,
freeze_und
=
False
,
**
kwargs
,
):
super
().
__init__
(
vocab_size
=
vocab_size
,
hidden_size
=
hidden_size
,
intermediate_size
=
intermediate_size
,
num_hidden_layers
=
num_hidden_layers
,
num_attention_heads
=
num_attention_heads
,
num_key_value_heads
=
num_key_value_heads
,
hidden_act
=
hidden_act
,
max_position_embeddings
=
max_position_embeddings
,
initializer_range
=
initializer_range
,
rms_norm_eps
=
rms_norm_eps
,
use_cache
=
use_cache
,
tie_word_embeddings
=
tie_word_embeddings
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
use_sliding_window
=
use_sliding_window
,
sliding_window
=
sliding_window
,
max_window_layers
=
max_window_layers
,
attention_dropout
=
attention_dropout
,
is_causal
=
is_causal
,
_attn_implementation
=
_attn_implementation
,
**
kwargs
,
)
self
.
qk_norm
=
qk_norm
self
.
layer_module
=
layer_module
self
.
freeze_und
=
freeze_und
class
NaiveCache
:
def
__init__
(
self
,
num_layers
):
self
.
key_cache
=
{
k
:
None
for
k
in
range
(
num_layers
)}
self
.
value_cache
=
{
k
:
None
for
k
in
range
(
num_layers
)}
@
property
def
num_layers
(
self
):
return
len
(
self
.
key_cache
)
@
property
def
seq_lens
(
self
):
if
self
.
key_cache
[
0
]
is
not
None
:
return
self
.
key_cache
[
0
].
shape
[
0
]
else
:
return
0
@
dataclass
class
BaseNavitOutputWithPast
(
ModelOutput
):
packed_query_sequence
:
torch
.
FloatTensor
=
None
past_key_values
:
Optional
[
NaiveCache
]
=
None
attentions
:
Optional
[
Tuple
[
torch
.
FloatTensor
]]
=
None
def
pad_sequence
(
tensor
,
pad_size
):
H
,
L
,
D
=
tensor
.
shape
pad_tensor
=
tensor
.
new_zeros
((
H
,
pad_size
,
D
))
return
torch
.
cat
([
tensor
,
pad_tensor
],
dim
=
1
)
class
PackedAttention
(
Qwen2Attention
):
def
__init__
(
self
,
config
,
layer_idx
:
Optional
[
int
]
=
None
):
super
().
__init__
(
config
,
layer_idx
)
if
self
.
config
.
qk_norm
:
self
.
q_norm
=
Qwen2RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
k_norm
=
Qwen2RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
q_norm
=
nn
.
Identity
()
self
.
k_norm
=
nn
.
Identity
()
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
training
:
return
self
.
forward_train
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_inference
(
*
args
,
**
kwargs
)
def
forward_train
(
self
,
packed_sequence
:
torch
.
Tensor
,
sample_lens
:
List
[
int
],
attention_mask
:
List
[
torch
.
Tensor
],
packed_position_embeddings
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
):
packed_query_states
=
self
.
q_proj
(
packed_sequence
).
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
packed_key_states
=
self
.
k_proj
(
packed_sequence
).
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_value_states
=
self
.
v_proj
(
packed_sequence
).
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_query_states
=
self
.
q_norm
(
packed_query_states
)
packed_key_states
=
self
.
k_norm
(
packed_key_states
)
packed_cos
,
packed_sin
=
packed_position_embeddings
packed_query_states
,
packed_key_states
=
apply_rotary_pos_emb
(
packed_query_states
,
packed_key_states
,
packed_cos
,
packed_sin
,
unsqueeze_dim
=
1
,
)
if
isinstance
(
attention_mask
,
List
):
packed_key_states
=
packed_key_states
[:,
:,
None
,
:].
repeat
(
1
,
1
,
self
.
num_key_value_groups
,
1
)
packed_key_states
=
packed_key_states
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
packed_value_states
=
packed_value_states
[:,
:,
None
,
:].
repeat
(
1
,
1
,
self
.
num_key_value_groups
,
1
)
packed_value_states
=
packed_value_states
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
unpacked_query_states
=
packed_query_states
.
transpose
(
0
,
1
).
split
(
sample_lens
,
dim
=
1
)
unpacked_key_states
=
packed_key_states
.
transpose
(
0
,
1
).
split
(
sample_lens
,
dim
=
1
)
unpacked_value_states
=
packed_value_states
.
transpose
(
0
,
1
).
split
(
sample_lens
,
dim
=
1
)
upacked_attn_output
=
[]
for
(
query_states
,
key_states
,
value_states
,
attention_mask_per_sample
,
)
in
zip
(
unpacked_query_states
,
unpacked_key_states
,
unpacked_value_states
,
attention_mask
,
):
with
sdpa_kernel
(
backends
=
[
SDPBackend
.
EFFICIENT_ATTENTION
]):
attn_output
=
scaled_dot_product_attention
(
query_states
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
key_states
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
value_states
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
attention_mask_per_sample
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
)
upacked_attn_output
.
append
(
attn_output
.
squeeze
(
0
))
packed_attn_output
=
torch
.
cat
(
upacked_attn_output
,
dim
=
1
)
else
:
pad_size
=
sum
(
sample_lens
)
-
packed_query_states
.
shape
[
0
]
packed_query_states
=
pad_sequence
(
packed_query_states
.
permute
(
1
,
0
,
2
),
pad_size
)
packed_key_states
=
pad_sequence
(
packed_key_states
.
permute
(
1
,
0
,
2
),
pad_size
)
packed_value_states
=
pad_sequence
(
packed_value_states
.
permute
(
1
,
0
,
2
),
pad_size
)
packed_attn_output
=
flex_attention
(
packed_query_states
.
unsqueeze
(
0
),
packed_key_states
.
unsqueeze
(
0
),
packed_value_states
.
unsqueeze
(
0
),
enable_gqa
=
True
,
block_mask
=
attention_mask
,
)
end_index
=
packed_attn_output
.
shape
[
2
]
-
pad_size
packed_attn_output
=
packed_attn_output
[
0
,
:,
:
end_index
,
:]
packed_attn_output
=
packed_attn_output
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
hidden_size
)
packed_attn_output
=
self
.
o_proj
(
packed_attn_output
)
return
packed_attn_output
def
forward_inference
(
self
,
packed_query_sequence
:
torch
.
Tensor
,
query_lens
:
torch
.
Tensor
,
packed_query_position_embeddings
:
torch
.
Tensor
,
packed_query_indexes
:
torch
.
Tensor
,
past_key_values
:
Optional
[
NaiveCache
]
=
None
,
key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_key_value_indexes
:
Optional
[
torch
.
Tensor
]
=
None
,
update_past_key_values
=
True
,
is_causal
=
True
,
output_attentions
=
False
,
):
packed_query_states
=
self
.
q_proj
(
packed_query_sequence
).
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
packed_key_states
=
self
.
k_proj
(
packed_query_sequence
).
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_value_states
=
self
.
v_proj
(
packed_query_sequence
).
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_query_states
=
self
.
q_norm
(
packed_query_states
)
packed_key_states
=
self
.
k_norm
(
packed_key_states
)
packed_cos
,
packed_sin
=
packed_query_position_embeddings
packed_query_states
,
packed_key_states
=
apply_rotary_pos_emb
(
packed_query_states
,
packed_key_states
,
packed_cos
,
packed_sin
,
unsqueeze_dim
=
1
,
)
packed_query_states
=
packed_query_states
.
to
(
torch
.
bfloat16
)
packed_key_states
=
packed_key_states
.
to
(
torch
.
bfloat16
)
packed_value_states
=
packed_value_states
.
to
(
torch
.
bfloat16
)
if
(
past_key_values
is
not
None
and
past_key_values
.
key_cache
[
self
.
layer_idx
]
is
not
None
):
past_key_states
=
past_key_values
.
key_cache
[
self
.
layer_idx
]
past_value_states
=
past_key_values
.
value_cache
[
self
.
layer_idx
]
seqlens
=
sum
(
query_lens
)
+
sum
(
key_values_lens
)
merged_key_states
=
past_key_states
.
new_zeros
(
(
seqlens
,
self
.
num_key_value_heads
,
self
.
head_dim
)
)
merged_value_states
=
past_key_states
.
new_zeros
(
(
seqlens
,
self
.
num_key_value_heads
,
self
.
head_dim
)
)
merged_key_states
[
packed_query_indexes
]
=
packed_key_states
merged_key_states
[
packed_key_value_indexes
]
=
past_key_states
merged_value_states
[
packed_query_indexes
]
=
packed_value_states
merged_value_states
[
packed_key_value_indexes
]
=
past_value_states
key_values_lens
=
key_values_lens
+
query_lens
else
:
merged_key_states
=
packed_key_states
merged_value_states
=
packed_value_states
key_values_lens
=
query_lens
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
torch
.
cumsum
(
query_lens
,
dim
=
0
),
(
1
,
0
))
cu_seqlens_k
=
torch
.
nn
.
functional
.
pad
(
torch
.
cumsum
(
key_values_lens
,
dim
=
0
),
(
1
,
0
)
)
packed_attn_output
=
flash_attn_varlen_func
(
q
=
packed_query_states
,
k
=
merged_key_states
,
v
=
merged_value_states
,
cu_seqlens_q
=
cu_seqlens_q
.
to
(
torch
.
int32
),
cu_seqlens_k
=
cu_seqlens_k
.
to
(
torch
.
int32
),
max_seqlen_q
=
max
(
query_lens
).
item
(),
max_seqlen_k
=
max
(
key_values_lens
).
item
(),
causal
=
is_causal
,
)
packed_attn_output
=
packed_attn_output
.
reshape
(
-
1
,
self
.
hidden_size
)
packed_attn_output
=
self
.
o_proj
(
packed_attn_output
)
if
update_past_key_values
:
past_key_values
.
key_cache
[
self
.
layer_idx
]
=
merged_key_states
past_key_values
.
value_cache
[
self
.
layer_idx
]
=
merged_value_states
return
packed_attn_output
,
past_key_values
class
PackedAttentionMoT
(
Qwen2Attention
):
def
__init__
(
self
,
config
,
layer_idx
:
Optional
[
int
]
=
None
):
super
().
__init__
(
config
,
layer_idx
)
if
self
.
config
.
qk_norm
:
self
.
q_norm
=
Qwen2RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
k_norm
=
Qwen2RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
q_norm_moe_gen
=
Qwen2RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
self
.
k_norm_moe_gen
=
Qwen2RMSNorm
(
self
.
head_dim
,
eps
=
config
.
rms_norm_eps
)
else
:
self
.
q_norm
=
nn
.
Identity
()
self
.
k_norm
=
nn
.
Identity
()
self
.
q_norm_moe_gen
=
nn
.
Identity
()
self
.
k_norm_moe_gen
=
nn
.
Identity
()
self
.
q_proj_moe_gen
=
nn
.
Linear
(
self
.
hidden_size
,
self
.
num_heads
*
self
.
head_dim
,
bias
=
True
)
self
.
k_proj_moe_gen
=
nn
.
Linear
(
self
.
hidden_size
,
self
.
num_key_value_heads
*
self
.
head_dim
,
bias
=
True
)
self
.
v_proj_moe_gen
=
nn
.
Linear
(
self
.
hidden_size
,
self
.
num_key_value_heads
*
self
.
head_dim
,
bias
=
True
)
self
.
o_proj_moe_gen
=
nn
.
Linear
(
self
.
num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
False
)
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
training
:
return
self
.
forward_train
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_inference
(
*
args
,
**
kwargs
)
def
forward_train
(
self
,
packed_sequence
:
torch
.
Tensor
,
sample_lens
:
List
[
int
],
attention_mask
,
packed_position_embeddings
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
packed_und_token_indexes
:
torch
.
LongTensor
,
packed_gen_token_indexes
:
torch
.
LongTensor
,
):
packed_query_states
=
packed_sequence
.
new_zeros
(
(
packed_sequence
.
shape
[
0
],
self
.
num_heads
*
self
.
head_dim
)
)
packed_key_states
=
packed_sequence
.
new_zeros
(
(
packed_sequence
.
shape
[
0
],
self
.
num_key_value_heads
*
self
.
head_dim
)
)
packed_value_states
=
packed_sequence
.
new_zeros
(
(
packed_sequence
.
shape
[
0
],
self
.
num_key_value_heads
*
self
.
head_dim
)
)
packed_sequence_und
=
packed_sequence
[
packed_und_token_indexes
]
packed_sequence_gen
=
packed_sequence
[
packed_gen_token_indexes
]
packed_query_states
[
packed_und_token_indexes
]
=
self
.
q_proj
(
packed_sequence_und
)
packed_query_states
[
packed_gen_token_indexes
]
=
self
.
q_proj_moe_gen
(
packed_sequence_gen
)
packed_key_states
[
packed_und_token_indexes
]
=
self
.
k_proj
(
packed_sequence_und
)
packed_key_states
[
packed_gen_token_indexes
]
=
self
.
k_proj_moe_gen
(
packed_sequence_gen
)
packed_value_states
[
packed_und_token_indexes
]
=
self
.
v_proj
(
packed_sequence_und
)
packed_value_states
[
packed_gen_token_indexes
]
=
self
.
v_proj_moe_gen
(
packed_sequence_gen
)
packed_query_states
=
packed_query_states
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
packed_key_states
=
packed_key_states
.
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_value_states
=
packed_value_states
.
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
if
self
.
config
.
freeze_und
:
packed_value_states
[
packed_und_token_indexes
]
=
packed_value_states
[
packed_und_token_indexes
].
detach
()
packed_query_states_
=
packed_query_states
.
new_zeros
(
packed_query_states
.
shape
)
packed_key_states_
=
packed_key_states
.
new_zeros
(
packed_key_states
.
shape
)
packed_query_states_
[
packed_und_token_indexes
]
=
self
.
q_norm
(
packed_query_states
[
packed_und_token_indexes
]
)
if
self
.
config
.
freeze_und
:
packed_query_states_
[
packed_und_token_indexes
]
=
packed_query_states_
[
packed_und_token_indexes
].
detach
()
packed_query_states_
[
packed_gen_token_indexes
]
=
self
.
q_norm_moe_gen
(
packed_query_states
[
packed_gen_token_indexes
]
)
packed_key_states_
[
packed_und_token_indexes
]
=
self
.
k_norm
(
packed_key_states
[
packed_und_token_indexes
]
)
if
self
.
config
.
freeze_und
:
packed_key_states_
[
packed_und_token_indexes
]
=
packed_key_states_
[
packed_und_token_indexes
].
detach
()
packed_key_states_
[
packed_gen_token_indexes
]
=
self
.
k_norm_moe_gen
(
packed_key_states
[
packed_gen_token_indexes
]
)
packed_cos
,
packed_sin
=
packed_position_embeddings
packed_query_states_
,
packed_key_states_
=
apply_rotary_pos_emb
(
packed_query_states_
,
packed_key_states_
,
packed_cos
,
packed_sin
,
unsqueeze_dim
=
1
,
)
if
isinstance
(
attention_mask
,
List
):
packed_key_states_
=
packed_key_states_
[:,
:,
None
,
:].
repeat
(
1
,
1
,
self
.
num_key_value_groups
,
1
)
packed_key_states_
=
packed_key_states_
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
packed_value_states
=
packed_value_states
[:,
:,
None
,
:].
repeat
(
1
,
1
,
self
.
num_key_value_groups
,
1
)
packed_value_states
=
packed_value_states
.
reshape
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
unpacked_query_states
=
packed_query_states_
.
transpose
(
0
,
1
).
split
(
sample_lens
,
dim
=
1
)
unpacked_key_states
=
packed_key_states_
.
transpose
(
0
,
1
).
split
(
sample_lens
,
dim
=
1
)
unpacked_value_states
=
packed_value_states
.
transpose
(
0
,
1
).
split
(
sample_lens
,
dim
=
1
)
upacked_attn_output
=
[]
for
(
query_states
,
key_states
,
value_states
,
attention_mask_per_sample
,
)
in
zip
(
unpacked_query_states
,
unpacked_key_states
,
unpacked_value_states
,
attention_mask
,
):
with
sdpa_kernel
(
backends
=
[
SDPBackend
.
EFFICIENT_ATTENTION
]):
attn_output
=
scaled_dot_product_attention
(
query_states
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
key_states
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
value_states
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
attention_mask_per_sample
.
to
(
torch
.
bfloat16
).
unsqueeze
(
0
),
)
upacked_attn_output
.
append
(
attn_output
.
squeeze
(
0
))
packed_attn_output
=
torch
.
cat
(
upacked_attn_output
,
dim
=
1
)
else
:
pad_size
=
sum
(
sample_lens
)
-
packed_query_states
.
shape
[
0
]
packed_query_states_
=
pad_sequence
(
packed_query_states_
.
permute
(
1
,
0
,
2
),
pad_size
)
packed_key_states_
=
pad_sequence
(
packed_key_states_
.
permute
(
1
,
0
,
2
),
pad_size
)
packed_value_states
=
pad_sequence
(
packed_value_states
.
permute
(
1
,
0
,
2
),
pad_size
)
packed_attn_output
=
flex_attention
(
packed_query_states_
.
unsqueeze
(
0
),
# 1, num_head, L, head_dim
packed_key_states_
.
unsqueeze
(
0
),
packed_value_states
.
unsqueeze
(
0
),
enable_gqa
=
True
,
block_mask
=
attention_mask
,
)
end_index
=
packed_attn_output
.
shape
[
2
]
-
pad_size
packed_attn_output
=
packed_attn_output
[
0
,
:,
:
end_index
,
:]
packed_attn_output
=
packed_attn_output
.
transpose
(
0
,
1
).
reshape
(
-
1
,
self
.
num_heads
*
self
.
head_dim
)
packed_attn_output_
=
packed_attn_output
.
new_zeros
(
packed_attn_output
.
shape
)
packed_attn_output_
[
packed_und_token_indexes
]
=
self
.
o_proj
(
packed_attn_output
[
packed_und_token_indexes
]
)
packed_attn_output_
[
packed_gen_token_indexes
]
=
self
.
o_proj_moe_gen
(
packed_attn_output
[
packed_gen_token_indexes
]
)
return
packed_attn_output_
def
forward_inference
(
self
,
packed_query_sequence
:
torch
.
Tensor
,
query_lens
:
torch
.
Tensor
,
packed_query_position_embeddings
:
torch
.
Tensor
,
packed_query_indexes
:
torch
.
Tensor
,
past_key_values
:
Optional
[
NaiveCache
]
=
None
,
key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_key_value_indexes
:
Optional
[
torch
.
Tensor
]
=
None
,
update_past_key_values
=
True
,
is_causal
=
True
,
mode
=
"und"
,
packed_vae_token_indexes
=
None
,
packed_text_indexes
=
None
,
):
if
mode
==
"und"
:
packed_query_states
=
self
.
q_proj
(
packed_query_sequence
).
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
packed_key_states
=
self
.
k_proj
(
packed_query_sequence
).
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_value_states
=
self
.
v_proj
(
packed_query_sequence
).
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_query_states
=
self
.
q_norm
(
packed_query_states
)
packed_key_states
=
self
.
k_norm
(
packed_key_states
)
elif
mode
==
"gen"
:
packed_query_sequence
=
packed_query_sequence
.
to
(
torch
.
bfloat16
)
packed_query_states
=
packed_query_sequence
.
new_zeros
(
(
packed_query_sequence
.
shape
[
0
],
self
.
num_heads
*
self
.
head_dim
)
)
packed_key_states
=
packed_query_sequence
.
new_zeros
(
(
packed_query_sequence
.
shape
[
0
],
self
.
num_key_value_heads
*
self
.
head_dim
,
)
)
packed_value_states
=
packed_query_sequence
.
new_zeros
(
(
packed_query_sequence
.
shape
[
0
],
self
.
num_key_value_heads
*
self
.
head_dim
,
)
)
packed_text_query_sequence
=
packed_query_sequence
[
packed_text_indexes
]
packed_vae_query_sequence
=
packed_query_sequence
[
packed_vae_token_indexes
]
packed_query_states
[
packed_text_indexes
]
=
self
.
q_proj
(
packed_text_query_sequence
)
packed_query_states
[
packed_vae_token_indexes
]
=
self
.
q_proj_moe_gen
(
packed_vae_query_sequence
)
packed_key_states
[
packed_text_indexes
]
=
self
.
k_proj
(
packed_text_query_sequence
)
packed_key_states
[
packed_vae_token_indexes
]
=
self
.
k_proj_moe_gen
(
packed_vae_query_sequence
)
packed_value_states
[
packed_text_indexes
]
=
self
.
v_proj
(
packed_text_query_sequence
)
packed_value_states
[
packed_vae_token_indexes
]
=
self
.
v_proj_moe_gen
(
packed_vae_query_sequence
)
packed_query_states
=
packed_query_states
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_dim
)
packed_key_states
=
packed_key_states
.
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_value_states
=
packed_value_states
.
view
(
-
1
,
self
.
num_key_value_heads
,
self
.
head_dim
)
packed_query_states
=
packed_query_states
.
to
(
torch
.
float32
)
packed_query_states
[
packed_text_indexes
]
=
self
.
q_norm
(
packed_query_states
[
packed_text_indexes
]
)
packed_query_states
[
packed_vae_token_indexes
]
=
self
.
q_norm_moe_gen
(
packed_query_states
[
packed_vae_token_indexes
]
)
packed_key_states
=
packed_key_states
.
to
(
torch
.
float32
)
packed_key_states
[
packed_text_indexes
]
=
self
.
k_norm
(
packed_key_states
[
packed_text_indexes
]
)
packed_key_states
[
packed_vae_token_indexes
]
=
self
.
k_norm_moe_gen
(
packed_key_states
[
packed_vae_token_indexes
]
)
packed_cos
,
packed_sin
=
packed_query_position_embeddings
packed_query_states
,
packed_key_states
=
apply_rotary_pos_emb
(
packed_query_states
,
packed_key_states
,
packed_cos
,
packed_sin
,
unsqueeze_dim
=
1
,
)
packed_query_states
=
packed_query_states
.
to
(
torch
.
bfloat16
)
packed_key_states
=
packed_key_states
.
to
(
torch
.
bfloat16
)
packed_value_states
=
packed_value_states
.
to
(
torch
.
bfloat16
)
if
(
past_key_values
is
not
None
and
past_key_values
.
key_cache
[
self
.
layer_idx
]
is
not
None
):
past_key_states
=
past_key_values
.
key_cache
[
self
.
layer_idx
]
past_value_states
=
past_key_values
.
value_cache
[
self
.
layer_idx
]
seqlens
=
sum
(
query_lens
)
+
sum
(
key_values_lens
)
merged_key_states
=
past_key_states
.
new_zeros
(
size
=
[
seqlens
,
self
.
num_key_value_heads
,
self
.
head_dim
]
)
merged_value_states
=
past_key_states
.
new_zeros
(
size
=
[
seqlens
,
self
.
num_key_value_heads
,
self
.
head_dim
]
)
merged_key_states
[
packed_query_indexes
]
=
packed_key_states
merged_key_states
[
packed_key_value_indexes
]
=
past_key_states
merged_value_states
[
packed_query_indexes
]
=
packed_value_states
merged_value_states
[
packed_key_value_indexes
]
=
past_value_states
key_values_lens
=
key_values_lens
+
query_lens
else
:
merged_key_states
=
packed_key_states
merged_value_states
=
packed_value_states
key_values_lens
=
query_lens
cu_seqlens_q
=
torch
.
nn
.
functional
.
pad
(
torch
.
cumsum
(
query_lens
,
dim
=
0
),
(
1
,
0
))
cu_seqlens_k
=
torch
.
nn
.
functional
.
pad
(
torch
.
cumsum
(
key_values_lens
,
dim
=
0
),
(
1
,
0
)
)
packed_attn_output
=
flash_attn_varlen_func
(
q
=
packed_query_states
,
k
=
merged_key_states
,
v
=
merged_value_states
,
cu_seqlens_q
=
cu_seqlens_q
.
to
(
torch
.
int32
),
cu_seqlens_k
=
cu_seqlens_k
.
to
(
torch
.
int32
),
max_seqlen_q
=
max
(
query_lens
).
item
(),
max_seqlen_k
=
max
(
key_values_lens
).
item
(),
causal
=
is_causal
,
)
packed_attn_output
=
packed_attn_output
.
reshape
(
-
1
,
self
.
hidden_size
)
if
mode
==
"und"
:
packed_attn_output
=
self
.
o_proj
(
packed_attn_output
)
elif
mode
==
"gen"
:
packed_attn_output
[
packed_text_indexes
]
=
self
.
o_proj
(
packed_attn_output
[
packed_text_indexes
]
)
packed_attn_output
[
packed_vae_token_indexes
]
=
self
.
o_proj_moe_gen
(
packed_attn_output
[
packed_vae_token_indexes
]
)
if
update_past_key_values
:
past_key_values
.
key_cache
[
self
.
layer_idx
]
=
merged_key_states
past_key_values
.
value_cache
[
self
.
layer_idx
]
=
merged_value_states
return
packed_attn_output
,
past_key_values
class
Qwen2DecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
layer_idx
:
Optional
[
int
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
PackedAttention
(
config
,
layer_idx
)
self
.
mlp
=
Qwen2MLP
(
config
)
self
.
input_layernorm
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
training
:
return
self
.
forward_train
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_inference
(
*
args
,
**
kwargs
)
def
forward_train
(
self
,
packed_sequence
:
torch
.
Tensor
,
sample_lens
:
List
[
int
],
attention_mask
,
packed_position_embeddings
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
)
->
torch
.
Tensor
:
residual
=
packed_sequence
packed_sequence
=
self
.
input_layernorm
(
packed_sequence
)
# Self Attention
packed_sequence
=
self
.
self_attn
(
packed_sequence
=
packed_sequence
,
sample_lens
=
sample_lens
,
attention_mask
=
attention_mask
,
packed_position_embeddings
=
packed_position_embeddings
,
)
packed_sequence
=
residual
+
packed_sequence
# Fully Connected
residual
=
packed_sequence
packed_sequence
=
self
.
post_attention_layernorm
(
packed_sequence
)
packed_sequence
=
self
.
mlp
(
packed_sequence
)
packed_sequence
=
residual
+
packed_sequence
return
packed_sequence
def
forward_inference
(
self
,
packed_query_sequence
:
torch
.
Tensor
,
query_lens
:
torch
.
Tensor
,
packed_query_position_embeddings
:
torch
.
Tensor
,
packed_query_indexes
:
torch
.
Tensor
,
past_key_values
:
Optional
[
NaiveCache
]
=
None
,
key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_key_value_indexes
:
Optional
[
torch
.
Tensor
]
=
None
,
update_past_key_values
=
True
,
is_causal
=
True
,
)
->
BaseNavitOutputWithPast
:
residual
=
packed_query_sequence
packed_query_sequence
=
self
.
input_layernorm
(
packed_query_sequence
)
# Self Attention
packed_query_sequence
,
past_key_values
=
self
.
self_attn
(
packed_query_sequence
=
packed_query_sequence
,
query_lens
=
query_lens
,
packed_query_position_embeddings
=
packed_query_position_embeddings
,
packed_query_indexes
=
packed_query_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
update_past_key_values
,
is_causal
=
is_causal
,
)
packed_query_sequence
=
residual
+
packed_query_sequence
# Fully Connected
residual
=
packed_query_sequence
packed_query_sequence
=
self
.
post_attention_layernorm
(
packed_query_sequence
)
packed_query_sequence
=
self
.
mlp
(
packed_query_sequence
)
packed_query_sequence
=
residual
+
packed_query_sequence
return
packed_query_sequence
,
past_key_values
class
Qwen2MoTDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
layer_idx
:
Optional
[
int
]
=
None
,
attn_module
:
Optional
[
Qwen2Attention
]
=
PackedAttentionMoT
,
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
freeze_und
=
config
.
freeze_und
self
.
self_attn
=
attn_module
(
config
,
layer_idx
)
self
.
mlp
=
Qwen2MLP
(
config
)
self
.
mlp_moe_gen
=
Qwen2MLP
(
config
)
self
.
input_layernorm
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
input_layernorm_moe_gen
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm_moe_gen
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
training
:
return
self
.
forward_train
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_inference
(
*
args
,
**
kwargs
)
def
forward_train
(
self
,
packed_sequence
:
torch
.
Tensor
,
sample_lens
:
List
[
int
],
attention_mask
,
packed_position_embeddings
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
packed_und_token_indexes
:
torch
.
LongTensor
,
packed_gen_token_indexes
:
torch
.
LongTensor
,
)
->
torch
.
Tensor
:
residual
=
packed_sequence
packed_sequence_
=
packed_sequence
.
new_zeros
(
packed_sequence
.
shape
)
packed_sequence_
[
packed_und_token_indexes
]
=
self
.
input_layernorm
(
packed_sequence
[
packed_und_token_indexes
]
)
packed_sequence_
[
packed_gen_token_indexes
]
=
self
.
input_layernorm_moe_gen
(
packed_sequence
[
packed_gen_token_indexes
]
)
# Self Attention
packed_sequence_
=
self
.
self_attn
(
packed_sequence
=
packed_sequence_
,
sample_lens
=
sample_lens
,
attention_mask
=
attention_mask
,
packed_position_embeddings
=
packed_position_embeddings
,
packed_und_token_indexes
=
packed_und_token_indexes
,
packed_gen_token_indexes
=
packed_gen_token_indexes
,
)
if
self
.
freeze_und
:
packed_sequence_
[
packed_und_token_indexes
]
=
packed_sequence_
[
packed_und_token_indexes
].
detach
()
packed_sequence
=
residual
+
packed_sequence_
# Fully Connected
residual
=
packed_sequence
packed_sequence_
=
packed_sequence
.
new_zeros
(
packed_sequence
.
shape
)
packed_sequence_
[
packed_und_token_indexes
]
=
self
.
mlp
(
self
.
post_attention_layernorm
(
packed_sequence
[
packed_und_token_indexes
])
)
if
self
.
freeze_und
:
packed_sequence_
[
packed_und_token_indexes
]
=
packed_sequence_
[
packed_und_token_indexes
].
detach
()
packed_sequence_
[
packed_gen_token_indexes
]
=
self
.
mlp_moe_gen
(
self
.
post_attention_layernorm_moe_gen
(
packed_sequence
[
packed_gen_token_indexes
]
)
)
packed_sequence
=
residual
+
packed_sequence_
return
packed_sequence
def
forward_inference
(
self
,
packed_query_sequence
:
torch
.
Tensor
,
query_lens
:
torch
.
Tensor
,
packed_query_position_embeddings
:
torch
.
Tensor
,
packed_query_indexes
:
torch
.
Tensor
,
past_key_values
:
Optional
[
NaiveCache
]
=
None
,
key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_key_value_indexes
:
Optional
[
torch
.
Tensor
]
=
None
,
update_past_key_values
=
True
,
is_causal
=
True
,
mode
=
"und"
,
packed_vae_token_indexes
=
None
,
packed_text_indexes
=
None
,
)
->
BaseNavitOutputWithPast
:
residual
=
packed_query_sequence
if
mode
==
"und"
:
packed_query_sequence
=
self
.
input_layernorm
(
packed_query_sequence
)
elif
mode
==
"gen"
:
packed_query_sequence_
=
torch
.
zeros_like
(
packed_query_sequence
)
packed_query_sequence_
[
packed_text_indexes
]
=
self
.
input_layernorm
(
packed_query_sequence
[
packed_text_indexes
]
)
packed_query_sequence_
[
packed_vae_token_indexes
]
=
(
self
.
input_layernorm_moe_gen
(
packed_query_sequence
[
packed_vae_token_indexes
]
)
)
packed_query_sequence
=
packed_query_sequence_
# Self Attention
packed_query_sequence
,
past_key_values
=
self
.
self_attn
(
packed_query_sequence
=
packed_query_sequence
,
query_lens
=
query_lens
,
packed_query_position_embeddings
=
packed_query_position_embeddings
,
packed_query_indexes
=
packed_query_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
update_past_key_values
,
is_causal
=
is_causal
,
mode
=
mode
,
packed_vae_token_indexes
=
packed_vae_token_indexes
,
packed_text_indexes
=
packed_text_indexes
,
)
packed_query_sequence
=
residual
+
packed_query_sequence
# Fully Connected
residual
=
packed_query_sequence
if
mode
==
"und"
:
packed_query_sequence
=
self
.
post_attention_layernorm
(
packed_query_sequence
)
packed_query_sequence
=
self
.
mlp
(
packed_query_sequence
)
elif
mode
==
"gen"
:
packed_text_query_sequence
=
packed_query_sequence
[
packed_text_indexes
]
packed_vae_query_sequence
=
packed_query_sequence
[
packed_vae_token_indexes
]
packed_text_query_sequence
=
self
.
post_attention_layernorm
(
packed_text_query_sequence
).
to
(
torch
.
bfloat16
)
packed_vae_query_sequence
=
self
.
post_attention_layernorm_moe_gen
(
packed_vae_query_sequence
).
to
(
torch
.
bfloat16
)
packed_query_sequence_
=
torch
.
zeros_like
(
packed_query_sequence
).
to
(
torch
.
bfloat16
)
packed_query_sequence_
[
packed_text_indexes
]
=
self
.
mlp
(
packed_text_query_sequence
)
packed_query_sequence_
[
packed_vae_token_indexes
]
=
self
.
mlp_moe_gen
(
packed_vae_query_sequence
)
packed_query_sequence
=
packed_query_sequence_
packed_query_sequence
=
residual
+
packed_query_sequence
return
packed_query_sequence
,
past_key_values
class
Qwen2MoEDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
,
layer_idx
:
Optional
[
int
]
=
None
):
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
PackedAttention
(
config
,
layer_idx
)
self
.
mlp
=
Qwen2MLP
(
config
)
self
.
mlp_moe_gen
=
Qwen2MLP
(
config
)
self
.
input_layernorm
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
training
:
return
self
.
forward_train
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_inference
(
*
args
,
**
kwargs
)
def
forward_train
(
self
,
packed_sequence
:
torch
.
Tensor
,
sample_lens
:
List
[
int
],
attention_mask
,
packed_position_embeddings
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
packed_und_token_indexes
:
torch
.
LongTensor
,
packed_gen_token_indexes
:
torch
.
LongTensor
,
)
->
torch
.
Tensor
:
residual
=
packed_sequence
packed_sequence
=
self
.
input_layernorm
(
packed_sequence
)
# Self Attention
packed_sequence
=
self
.
self_attn
(
packed_sequence
=
packed_sequence
,
sample_lens
=
sample_lens
,
attention_mask
=
attention_mask
,
packed_position_embeddings
=
packed_position_embeddings
,
)
packed_sequence
=
residual
+
packed_sequence
# Fully Connected
residual
=
packed_sequence
packed_sequence
=
self
.
post_attention_layernorm
(
packed_sequence
)
packed_sequence_new
=
packed_sequence
.
new_zeros
(
packed_sequence
.
shape
)
packed_sequence_und
=
self
.
mlp
(
packed_sequence
[
packed_und_token_indexes
])
packed_sequence_gen
=
self
.
mlp_moe_gen
(
packed_sequence
[
packed_gen_token_indexes
]
)
packed_sequence_new
[
packed_und_token_indexes
]
=
packed_sequence_und
packed_sequence_new
[
packed_gen_token_indexes
]
=
packed_sequence_gen
packed_sequence
=
residual
+
packed_sequence_new
return
packed_sequence
def
forward_inference
(
self
,
packed_query_sequence
:
torch
.
Tensor
,
query_lens
:
torch
.
Tensor
,
packed_query_position_embeddings
:
torch
.
Tensor
,
packed_query_indexes
:
torch
.
Tensor
,
past_key_values
:
Optional
[
NaiveCache
]
=
None
,
key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_key_value_indexes
:
Optional
[
torch
.
Tensor
]
=
None
,
update_past_key_values
=
True
,
is_causal
=
True
,
mode
=
"und"
,
packed_vae_token_indexes
=
None
,
packed_text_indexes
=
None
,
)
->
BaseNavitOutputWithPast
:
residual
=
packed_query_sequence
packed_query_sequence
=
self
.
input_layernorm
(
packed_query_sequence
)
# Self Attention
packed_query_sequence
,
past_key_values
=
self
.
self_attn
(
packed_query_sequence
=
packed_query_sequence
,
query_lens
=
query_lens
,
packed_query_position_embeddings
=
packed_query_position_embeddings
,
packed_query_indexes
=
packed_query_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
update_past_key_values
,
is_causal
=
is_causal
,
)
packed_query_sequence
=
residual
+
packed_query_sequence
# Fully Connected
residual
=
packed_query_sequence
packed_query_sequence
=
self
.
post_attention_layernorm
(
packed_query_sequence
)
if
mode
==
"und"
:
packed_query_sequence
=
self
.
mlp
(
packed_query_sequence
)
elif
mode
==
"gen"
:
packed_query_sequence_
=
torch
.
zeros_like
(
packed_query_sequence
).
to
(
torch
.
bfloat16
)
packed_query_sequence_
[
packed_text_indexes
]
=
self
.
mlp
(
packed_query_sequence
[
packed_text_indexes
]
)
packed_query_sequence_
[
packed_vae_token_indexes
]
=
self
.
mlp_moe_gen
(
packed_query_sequence
[
packed_vae_token_indexes
]
)
packed_query_sequence
=
packed_query_sequence_
packed_query_sequence
=
residual
+
packed_query_sequence
return
packed_query_sequence
,
past_key_values
Decoder_layer_dict
=
{
"Qwen2DecoderLayer"
:
Qwen2DecoderLayer
,
"Qwen2MoEDecoderLayer"
:
Qwen2MoEDecoderLayer
,
"Qwen2MoTDecoderLayer"
:
partial
(
Qwen2MoTDecoderLayer
,
attn_module
=
PackedAttentionMoT
),
}
class
Qwen2Model
(
Qwen2PreTrainedModel
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
use_moe
=
"Mo"
in
config
.
layer_module
self
.
embed_tokens
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
hidden_size
,
self
.
padding_idx
)
layer_module
=
Decoder_layer_dict
[
config
.
layer_module
]
self
.
layers
=
nn
.
ModuleList
(
[
layer_module
(
config
,
layer_idx
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
]
)
self
.
norm
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
if
self
.
use_moe
:
self
.
norm_moe_gen
=
Qwen2RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
rotary_emb
=
Qwen2RotaryEmbedding
(
config
=
config
)
# Initialize weights and apply final processing
self
.
post_init
()
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
training
:
return
self
.
forward_train
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_inference
(
*
args
,
**
kwargs
)
def
forward_train
(
self
,
packed_sequence
:
torch
.
Tensor
,
sample_lens
:
List
[
int
],
attention_mask
,
packed_position_ids
:
torch
.
Tensor
,
packed_und_token_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
packed_gen_token_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
torch
.
Tensor
:
if
self
.
config
.
freeze_und
:
packed_sequence
[
packed_und_token_indexes
]
=
packed_sequence
[
packed_und_token_indexes
].
detach
()
# create position embeddings to be shared across the decoder layers
cos
,
sin
=
self
.
rotary_emb
(
packed_sequence
,
packed_position_ids
.
unsqueeze
(
0
))
cos
=
cos
.
squeeze
(
0
)
sin
=
sin
.
squeeze
(
0
)
packed_position_embeddings
=
(
cos
,
sin
)
extra_inputs
=
{}
if
self
.
use_moe
:
assert
packed_und_token_indexes
is
not
None
if
packed_gen_token_indexes
is
None
:
packed_gen_token_indexes
=
packed_und_token_indexes
.
new_ones
(
size
=
[
0
])
extra_inputs
.
update
(
packed_und_token_indexes
=
packed_und_token_indexes
,
packed_gen_token_indexes
=
packed_gen_token_indexes
,
)
for
decoder_layer
in
self
.
layers
:
packed_sequence
=
decoder_layer
(
packed_sequence
=
packed_sequence
,
sample_lens
=
sample_lens
,
attention_mask
=
attention_mask
,
packed_position_embeddings
=
packed_position_embeddings
,
**
extra_inputs
,
)
if
self
.
use_moe
:
packed_sequence_
=
torch
.
zeros_like
(
packed_sequence
)
packed_sequence_
[
packed_und_token_indexes
]
=
self
.
norm
(
packed_sequence
[
packed_und_token_indexes
]
)
if
self
.
config
.
freeze_und
:
packed_sequence_
[
packed_und_token_indexes
]
=
packed_sequence_
[
packed_und_token_indexes
].
detach
()
packed_sequence_
[
packed_gen_token_indexes
]
=
self
.
norm_moe_gen
(
packed_sequence
[
packed_gen_token_indexes
]
)
return
packed_sequence_
else
:
return
self
.
norm
(
packed_sequence
)
def
forward_inference
(
self
,
packed_query_sequence
:
torch
.
Tensor
,
query_lens
:
torch
.
Tensor
,
packed_query_position_ids
:
torch
.
Tensor
,
packed_query_indexes
:
torch
.
Tensor
,
past_key_values
:
Optional
[
NaiveCache
]
=
None
,
key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_key_value_indexes
:
Optional
[
torch
.
Tensor
]
=
None
,
update_past_key_values
=
True
,
is_causal
=
True
,
mode
=
"und"
,
packed_vae_token_indexes
=
None
,
packed_text_indexes
=
None
,
)
->
BaseNavitOutputWithPast
:
# create position embeddings to be shared across the decoder layers
cos
,
sin
=
self
.
rotary_emb
(
packed_query_sequence
,
packed_query_position_ids
.
unsqueeze
(
0
)
)
cos
=
cos
.
squeeze
(
0
)
sin
=
sin
.
squeeze
(
0
)
packed_query_position_embeddings
=
(
cos
,
sin
)
extra_inputs
=
{}
if
self
.
use_moe
:
extra_inputs
.
update
(
mode
=
mode
)
if
mode
==
"gen"
:
assert
packed_vae_token_indexes
is
not
None
assert
packed_text_indexes
is
not
None
extra_inputs
.
update
(
packed_vae_token_indexes
=
packed_vae_token_indexes
,
packed_text_indexes
=
packed_text_indexes
,
)
for
decoder_layer
in
self
.
layers
:
packed_query_sequence
,
past_key_values
=
decoder_layer
(
packed_query_sequence
=
packed_query_sequence
,
query_lens
=
query_lens
,
packed_query_position_embeddings
=
packed_query_position_embeddings
,
packed_query_indexes
=
packed_query_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
update_past_key_values
,
is_causal
=
is_causal
,
**
extra_inputs
,
)
if
self
.
use_moe
:
if
mode
==
"und"
:
packed_query_sequence
=
self
.
norm
(
packed_query_sequence
)
elif
mode
==
"gen"
:
packed_query_sequence_
=
torch
.
zeros_like
(
packed_query_sequence
)
packed_query_sequence_
[
packed_text_indexes
]
=
self
.
norm
(
packed_query_sequence
[
packed_text_indexes
]
)
packed_query_sequence_
[
packed_vae_token_indexes
]
=
self
.
norm_moe_gen
(
packed_query_sequence
[
packed_vae_token_indexes
]
)
packed_query_sequence
=
packed_query_sequence_
else
:
packed_query_sequence
=
self
.
norm
(
packed_query_sequence
)
return
BaseNavitOutputWithPast
(
packed_query_sequence
=
packed_query_sequence
,
past_key_values
=
past_key_values
,
)
class
Qwen2ForCausalLM
(
Qwen2PreTrainedModel
):
_tied_weights_keys
=
[
"lm_head.weight"
]
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
model
=
Qwen2Model
(
config
)
self
.
vocab_size
=
config
.
vocab_size
self
.
lm_head
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
vocab_size
,
bias
=
False
)
# Initialize weights and apply final processing
self
.
post_init
()
def
init_moe
(
self
):
for
name
,
param
in
self
.
named_parameters
():
if
"moe_gen"
in
name
:
original_name
=
name
.
replace
(
"_moe_gen"
,
""
)
param
.
data
.
copy_
(
self
.
state_dict
()[
original_name
].
data
)
def
get_input_embeddings
(
self
):
return
self
.
model
.
embed_tokens
def
set_input_embeddings
(
self
,
value
):
self
.
model
.
embed_tokens
=
value
def
get_output_embeddings
(
self
):
return
self
.
lm_head
def
set_output_embeddings
(
self
,
new_embeddings
):
self
.
lm_head
=
new_embeddings
def
set_decoder
(
self
,
decoder
):
self
.
model
=
decoder
def
get_decoder
(
self
):
return
self
.
model
def
forward
(
self
,
*
args
,
**
kwargs
):
if
self
.
training
:
return
self
.
forward_train
(
*
args
,
**
kwargs
)
else
:
return
self
.
forward_inference
(
*
args
,
**
kwargs
)
def
forward_train
(
self
,
packed_sequence
:
torch
.
Tensor
,
sample_lens
:
List
[
int
],
attention_mask
,
packed_position_ids
:
torch
.
Tensor
,
packed_und_token_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
packed_gen_token_indexes
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
torch
.
Tensor
:
outputs
=
self
.
model
(
packed_sequence
=
packed_sequence
,
sample_lens
=
sample_lens
,
packed_position_ids
=
packed_position_ids
,
attention_mask
=
attention_mask
,
packed_und_token_indexes
=
packed_und_token_indexes
,
packed_gen_token_indexes
=
packed_gen_token_indexes
,
)
return
outputs
def
forward_inference
(
self
,
packed_query_sequence
:
torch
.
Tensor
,
query_lens
:
torch
.
Tensor
,
packed_query_position_ids
:
torch
.
Tensor
,
packed_query_indexes
:
torch
.
Tensor
,
past_key_values
:
Optional
[
NaiveCache
]
=
None
,
key_values_lens
:
Optional
[
torch
.
Tensor
]
=
None
,
packed_key_value_indexes
:
Optional
[
torch
.
Tensor
]
=
None
,
update_past_key_values
=
True
,
is_causal
=
True
,
mode
=
"und"
,
packed_vae_token_indexes
=
None
,
packed_text_indexes
=
None
,
)
->
BaseNavitOutputWithPast
:
outputs
=
self
.
model
(
packed_query_sequence
=
packed_query_sequence
,
query_lens
=
query_lens
,
packed_query_position_ids
=
packed_query_position_ids
,
packed_query_indexes
=
packed_query_indexes
,
past_key_values
=
past_key_values
,
key_values_lens
=
key_values_lens
,
packed_key_value_indexes
=
packed_key_value_indexes
,
update_past_key_values
=
update_past_key_values
,
is_causal
=
is_causal
,
mode
=
mode
,
packed_vae_token_indexes
=
packed_vae_token_indexes
,
packed_text_indexes
=
packed_text_indexes
,
)
return
outputs
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment