"vscode:/vscode.git/clone" did not exist on "47684368dbbe4185d068be77d32a962059cfc37c"
Commit 876a36a4 authored by raojy's avatar raojy
Browse files

first

parent eda2afb8
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import math
import random
import numpy as np
import torch
from PIL import Image
from torch.nn.attention.flex_attention import and_masks, or_masks
def create_sparse_mask(document_lens, split_lens, attn_modes, device):
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
def full_and_noise_mask(b, h, q_idx, kv_idx):
return (full_and_noise_seq_id[q_idx] == full_and_noise_seq_id[kv_idx]) & (
full_and_noise_seq_id[q_idx] >= 0
)
def remove_noise_mask(b, h, q_idx, kv_idx):
return ~(
(noise_seq_id[kv_idx] >= 0) & (noise_seq_id[q_idx] != noise_seq_id[kv_idx])
)
def sample_mask(b, h, q_idx, kv_idx):
return document_id[q_idx] == document_id[kv_idx]
full_and_noise_tmp = []
noise_tmp = []
for i, (length, model) in enumerate(zip(split_lens, attn_modes)):
value = i if model in ["full", "noise"] else -1
full_and_noise_tmp.extend([value] * length)
value_noise = i if model == "noise" else -1
noise_tmp.extend([value_noise] * length)
full_and_noise_seq_id = torch.Tensor(full_and_noise_tmp).to(device)
noise_seq_id = torch.Tensor(noise_tmp).to(device)
document_id = torch.cat(
[torch.full((l,), i) for i, l in enumerate(document_lens, start=1)]
).to(device)
return and_masks(
or_masks(causal_mask, full_and_noise_mask), remove_noise_mask, sample_mask
)
def patchify(image, patch_size):
p = patch_size
c, h, w = image.shape
assert h % p == 0 and w % p == 0
image = image.reshape(c, h // p, p, w // p, p)
image = torch.einsum("chpwq->hwpqc", image)
image = image.reshape(-1, p**2 * c)
return image
def get_flattened_position_ids_extrapolate(
img_h, img_w, patch_size, max_num_patches_per_side
):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
coords_h = torch.arange(0, num_patches_h)
coords_w = torch.arange(0, num_patches_w)
pos_ids = (coords_h[:, None] * max_num_patches_per_side + coords_w).flatten()
return pos_ids
def get_flattened_position_ids_interpolate(
img_h, img_w, patch_size, max_num_patches_per_side
):
num_patches_h, num_patches_w = img_h // patch_size, img_w // patch_size
boundaries = torch.arange(
1 / max_num_patches_per_side, 1.0, 1 / max_num_patches_per_side
)
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / num_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / num_patches_w)
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (
bucket_coords_h[:, None] * max_num_patches_per_side + bucket_coords_w
).flatten()
return pos_ids
def prepare_attention_mask_per_sample(split_lens, attn_modes, device="cpu"):
"""
nested_split_lens: A list of N lists of ints. Each int indicates the length of a split within
a sample, where each sample contains multiple splits with different attn modes.
nested_attn_modes: whether to use full attn in each split.
"""
sample_len = sum(split_lens)
attention_mask = torch.zeros(
(sample_len, sample_len), dtype=torch.bool, device=device
)
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
assert attn_mode in ["causal", "full", "noise"]
if attn_mode == "causal":
attention_mask[csum : csum + s, csum : csum + s] = torch.ones(
(s, s), device=device
).tril()
attention_mask[csum : csum + s, :csum] = 1
else:
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
attention_mask[csum : csum + s, :csum] = 1
csum += s
csum = 0
for s, attn_mode in zip(split_lens, attn_modes):
if attn_mode == "noise":
attention_mask[:, csum : csum + s] = torch.zeros((sample_len, s))
attention_mask[csum : csum + s, csum : csum + s] = torch.ones((s, s))
csum += s
attention_mask = torch.zeros_like(attention_mask, dtype=torch.float).masked_fill_(
~attention_mask, float("-inf")
)
return attention_mask
def split_integer_exp_decay(S, ng_sample_decay=1.0):
if ng_sample_decay == 1.0:
N = random.randint(1, S)
else:
base = (1 - ng_sample_decay) / (1 - math.pow(ng_sample_decay, S))
p = [base * math.pow(ng_sample_decay, i) for i in range(S)]
N = random.choices(list(range(1, S + 1)), p, k=1)[0]
cumsum = [0] + sorted(random.sample(range(1, S), N - 1)) + [S]
result = [cumsum[i + 1] - cumsum[i] for i in range(len(cumsum) - 1)]
return result, cumsum
def pil_img2rgb(image):
if image.mode == "RGBA" or image.info.get("transparency", None) is not None:
image = image.convert("RGBA")
white = Image.new(mode="RGB", size=image.size, color=(255, 255, 255))
white.paste(image, mask=image.split()[3])
image = white
else:
image = image.convert("RGB")
return image
def add_special_tokens(tokenizer):
all_special_tokens = []
for k, v in tokenizer.special_tokens_map.items():
if isinstance(v, str):
all_special_tokens.append(v)
elif isinstance(v, list):
all_special_tokens += v
new_tokens = []
if "<|im_start|>" not in all_special_tokens:
new_tokens.append("<|im_start|>")
if "<|im_end|>" not in all_special_tokens:
new_tokens.append("<|im_end|>")
if "<|vision_start|>" not in all_special_tokens:
new_tokens.append("<|vision_start|>")
if "<|vision_end|>" not in all_special_tokens:
new_tokens.append("<|vision_end|>")
num_new_tokens = tokenizer.add_tokens(new_tokens)
bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
start_of_image = tokenizer.convert_tokens_to_ids("<|vision_start|>")
end_of_image = tokenizer.convert_tokens_to_ids("<|vision_end|>")
new_token_ids = dict(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
start_of_image=start_of_image,
end_of_image=end_of_image,
)
return tokenizer, new_token_ids, num_new_tokens
def len2weight(x, loss_reduction="square"):
if x == 0:
return x
if loss_reduction == "token":
return 1
if loss_reduction == "sample":
return 1 / x
if loss_reduction == "square":
return 1 / (x**0.5)
raise NotImplementedError(loss_reduction)
def load_image(image_path):
return Image.open(image_path)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import json
import random
import numpy as np
import torch
from .data_utils import (
get_flattened_position_ids_extrapolate,
get_flattened_position_ids_interpolate,
len2weight,
patchify,
prepare_attention_mask_per_sample,
)
from .dataset_info import DATASET_INFO, DATASET_REGISTRY
from .transforms import ImageTransform
from .video_utils import FrameSampler
class DataConfig:
def __init__(
self,
grouped_datasets,
text_cond_dropout_prob=0.1,
vit_cond_dropout_prob=0.4,
vae_cond_dropout_prob=0.1,
vae_image_downsample=16,
max_latent_size=32,
vit_patch_size=14,
max_num_patch_per_side=70,
):
self.grouped_datasets = grouped_datasets
self.text_cond_dropout_prob = text_cond_dropout_prob
self.vit_cond_dropout_prob = vit_cond_dropout_prob
self.vit_patch_size = vit_patch_size
self.max_num_patch_per_side = max_num_patch_per_side
self.vae_cond_dropout_prob = vae_cond_dropout_prob
self.vae_image_downsample = vae_image_downsample
self.max_latent_size = max_latent_size
class PackedDataset(torch.utils.data.IterableDataset):
bos_token_id: int
eos_token_id: int
start_of_image: int
end_of_image: int
def __init__(
self,
data_config,
tokenizer,
special_tokens,
local_rank,
world_size,
num_workers,
expected_num_tokens=32768,
max_num_tokens_per_sample=16384,
max_num_tokens=36864,
prefer_buffer_before=16384,
max_buffer_size=50,
interpolate_pos=False,
use_flex=False,
data_status=None,
):
super().__init__()
self.expected_num_tokens = expected_num_tokens
self.max_num_tokens_per_sample = max_num_tokens_per_sample
self.prefer_buffer_before = prefer_buffer_before
self.max_num_tokens = max_num_tokens
self.max_buffer_size = max_buffer_size
self.tokenizer = tokenizer
self.local_rank = local_rank
self.world_size = world_size
self.num_workers = num_workers
self.use_flex = use_flex
for k, v in special_tokens.items():
setattr(self, k, v)
grouped_datasets, is_mandatory, grouped_weights = self.build_datasets(
data_config.grouped_datasets, data_status
)
self.grouped_datasets = grouped_datasets
self.dataset_iters = [iter(dataset) for dataset in grouped_datasets]
self.is_mandatory = is_mandatory
self.grouped_weights = grouped_weights
self.data_config = data_config
self.interpolate_pos = interpolate_pos
if self.interpolate_pos:
self.get_flattened_position_ids = get_flattened_position_ids_interpolate
else:
self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
def build_datasets(self, datasets_metainfo, data_status):
datasets = []
is_mandatory = []
grouped_weights = []
for grouped_dataset_name, dataset_args in datasets_metainfo.items():
is_mandatory.append(dataset_args.pop("is_mandatory", False))
grouped_weights.append(dataset_args.pop("weight", 0.0))
if "frame_sampler_args" in dataset_args.keys():
frame_sampler = FrameSampler(**dataset_args.pop("frame_sampler_args"))
dataset_args["frame_sampler"] = frame_sampler
if "image_transform_args" in dataset_args.keys():
transform = ImageTransform(**dataset_args.pop("image_transform_args"))
dataset_args["transform"] = transform
if "vit_image_transform_args" in dataset_args.keys():
vit_transform = ImageTransform(
**dataset_args.pop("vit_image_transform_args")
)
dataset_args["vit_transform"] = vit_transform
if "dataset_names" in dataset_args.keys():
dataset_names = dataset_args.pop("dataset_names")
else:
dataset_names = DATASET_INFO[grouped_dataset_name].keys()
if "num_used_data" not in dataset_args.keys():
dataset_args["num_used_data"] = []
append_num_used_data = True
else:
append_num_used_data = False
dataset_args["data_dir_list"] = []
for item in dataset_names:
if self.local_rank == 0:
print(f"Preparing Dataset {grouped_dataset_name}/{item}")
meta_info = DATASET_INFO[grouped_dataset_name][item]
dataset_args["data_dir_list"].append(meta_info["data_dir"])
if append_num_used_data:
dataset_args["num_used_data"].append(meta_info["num_total_samples"])
if "parquet_info_path" in meta_info.keys():
if "parquet_info" not in dataset_args.keys():
dataset_args["parquet_info"] = {}
with open(meta_info["parquet_info_path"], "r") as f:
parquet_info = json.load(f)
dataset_args["parquet_info"].update(parquet_info)
if "json_dir" in meta_info.keys():
# parquet/tar with json
if "json_dir_list" not in dataset_args.keys():
dataset_args["json_dir_list"] = [meta_info["json_dir"]]
else:
dataset_args["json_dir_list"].append(meta_info["json_dir"])
if "jsonl_path" in meta_info.keys():
# jsonl with jpeg
if "jsonl_path_list" not in dataset_args.keys():
dataset_args["jsonl_path_list"] = [meta_info["jsonl_path"]]
else:
dataset_args["jsonl_path_list"].append(meta_info["jsonl_path"])
resume_data_status = dataset_args.pop("resume_data_status", True)
if (
data_status is not None
and grouped_dataset_name in data_status.keys()
and resume_data_status
):
data_status_per_group = data_status[grouped_dataset_name]
else:
data_status_per_group = None
dataset = DATASET_REGISTRY[grouped_dataset_name](
dataset_name=grouped_dataset_name,
tokenizer=self.tokenizer,
local_rank=self.local_rank,
world_size=self.world_size,
num_workers=self.num_workers,
data_status=data_status_per_group,
**dataset_args,
)
datasets.append(dataset)
return datasets, is_mandatory, grouped_weights
def set_epoch(self, seed):
for dataset in self.grouped_datasets:
dataset.set_epoch(seed)
def set_sequence_status(self):
sequence_status = dict(
curr=0,
sample_lens=list(),
packed_position_ids=list(),
nested_attention_masks=list(),
split_lens=list(),
attn_modes=list(),
packed_text_ids=list(),
packed_text_indexes=list(),
packed_label_ids=list(),
ce_loss_indexes=list(),
ce_loss_weights=list(),
vae_image_tensors=list(),
packed_latent_position_ids=list(),
vae_latent_shapes=list(),
packed_vae_token_indexes=list(),
packed_timesteps=list(),
mse_loss_indexes=list(),
packed_vit_tokens=list(),
vit_token_seqlens=list(),
packed_vit_position_ids=list(),
packed_vit_token_indexes=list(),
)
return sequence_status
def to_tensor(self, sequence_status):
data = dict(
sequence_length=sum(sequence_status["sample_lens"]),
sample_lens=sequence_status["sample_lens"],
packed_text_ids=torch.tensor(sequence_status["packed_text_ids"]),
packed_text_indexes=torch.tensor(sequence_status["packed_text_indexes"]),
packed_position_ids=torch.tensor(sequence_status["packed_position_ids"]),
)
if not self.use_flex:
data["nested_attention_masks"] = sequence_status["nested_attention_masks"]
else:
sequence_len = data["sequence_length"]
pad_len = self.max_num_tokens - sequence_len
data["split_lens"] = sequence_status["split_lens"] + [pad_len]
data["attn_modes"] = sequence_status["attn_modes"] + ["causal"]
data["sample_lens"] += [pad_len]
# if the model has a convnet vae (e.g., as visual tokenizer)
if len(sequence_status["vae_image_tensors"]) > 0:
image_tensors = sequence_status.pop("vae_image_tensors")
image_sizes = [item.shape for item in image_tensors]
max_image_size = [max(item) for item in list(zip(*image_sizes))]
padded_images = torch.zeros(size=(len(image_tensors), *max_image_size))
for i, image_tensor in enumerate(image_tensors):
padded_images[
i, :, : image_tensor.shape[1], : image_tensor.shape[2]
] = image_tensor
data["padded_images"] = padded_images
data["patchified_vae_latent_shapes"] = sequence_status["vae_latent_shapes"]
data["packed_latent_position_ids"] = torch.cat(
sequence_status["packed_latent_position_ids"], dim=0
)
data["packed_vae_token_indexes"] = torch.tensor(
sequence_status["packed_vae_token_indexes"]
)
# if the model has a vit (e.g., as visual tokenizer)
if len(sequence_status["packed_vit_tokens"]) > 0:
data["packed_vit_tokens"] = torch.cat(
sequence_status["packed_vit_tokens"], dim=0
)
data["packed_vit_position_ids"] = torch.cat(
sequence_status["packed_vit_position_ids"], dim=0
)
data["packed_vit_token_indexes"] = torch.tensor(
sequence_status["packed_vit_token_indexes"]
)
data["vit_token_seqlens"] = torch.tensor(
sequence_status["vit_token_seqlens"]
)
# if the model is required to perform visual generation
if len(sequence_status["packed_timesteps"]) > 0:
data["packed_timesteps"] = torch.tensor(sequence_status["packed_timesteps"])
data["mse_loss_indexes"] = torch.tensor(sequence_status["mse_loss_indexes"])
# if the model is required to perform text generation
if len(sequence_status["packed_label_ids"]) > 0:
data["packed_label_ids"] = torch.tensor(sequence_status["packed_label_ids"])
data["ce_loss_indexes"] = torch.tensor(sequence_status["ce_loss_indexes"])
data["ce_loss_weights"] = torch.tensor(sequence_status["ce_loss_weights"])
return data
def __iter__(self):
total_weights = sum(self.grouped_weights)
assert total_weights > 0.0
group_cumprobs = [
sum(self.grouped_weights[: i + 1]) / total_weights
for i in range(len(self.grouped_weights))
]
sequence_status = self.set_sequence_status()
batch_data_indexes = []
buffer = []
video_buffer = [] # Separate buffer for extremely long video samples
while True:
# Ensure at least one sample from each group
if sequence_status["curr"] == 0:
if len(video_buffer) > 0:
sample = video_buffer.pop(0)
num_tokens = sample["num_tokens"] + 2 * len(sample["sequence_plan"])
sequence_status = self.pack_sequence(sample, sequence_status)
batch_data_indexes.append(sample["data_indexes"])
else:
for group_index, group_iter in enumerate(self.dataset_iters):
if self.is_mandatory[group_index]:
while True:
sample = next(group_iter)
# if a sample is too long, put it in video buffer
num_tokens = sample["num_tokens"] + 2 * len(
sample["sequence_plan"]
)
if num_tokens > self.max_num_tokens_per_sample:
if len(video_buffer) < self.max_buffer_size:
video_buffer.append(sample)
print(
f"Added sample with length {num_tokens} to video_buffer (size: {len(video_buffer)})"
)
else:
print(
f"video_buffer full, skip a sample with length {num_tokens}"
)
break
elif num_tokens < self.max_num_tokens_per_sample:
sequence_status = self.pack_sequence(
sample, sequence_status
)
batch_data_indexes.append(sample["data_indexes"])
break
if sequence_status["curr"] >= self.expected_num_tokens:
data = self.to_tensor(sequence_status)
data["batch_data_indexes"] = batch_data_indexes
print(
f"Yielding {len(sequence_status['sample_lens'])} 3D data with length {sum(sequence_status['sample_lens'])}, length of each sample: {sequence_status['sample_lens']}"
)
yield data
sequence_status = self.set_sequence_status()
batch_data_indexes = []
if sequence_status["curr"] < self.prefer_buffer_before and len(buffer) > 0:
sample = buffer.pop(0)
sample_from_buffer = True
else:
# sample normally across all groups
n = random.random()
group_index = 0
for i, cumprob in enumerate(group_cumprobs):
if n < cumprob:
group_index = i
break
sample = next(self.dataset_iters[group_index])
sample_from_buffer = False
# if a sample is too long, skip it
num_tokens = sample["num_tokens"] + 2 * len(sample["sequence_plan"])
if num_tokens > self.max_num_tokens_per_sample:
if len(video_buffer) < self.max_buffer_size:
video_buffer.append(sample)
print(
f"Added sample with length {num_tokens} to video_buffer (size: {len(video_buffer)})"
)
else:
print(f"video_buffer full, skip a sample with length {num_tokens}")
continue
if sequence_status["curr"] + num_tokens > self.max_num_tokens:
if len(buffer) < self.max_buffer_size and not sample_from_buffer:
buffer.append(sample)
else:
# print(f"Yielding data with length {sum(sequence_status['sample_lens'])}")
data = self.to_tensor(sequence_status)
data["batch_data_indexes"] = batch_data_indexes
yield data
sequence_status = self.set_sequence_status()
batch_data_indexes = []
continue
sequence_status = self.pack_sequence(sample, sequence_status)
batch_data_indexes.append(sample["data_indexes"])
if sequence_status["curr"] >= self.expected_num_tokens:
data = self.to_tensor(sequence_status)
data["batch_data_indexes"] = batch_data_indexes
yield data
sequence_status = self.set_sequence_status()
batch_data_indexes = []
def pack_sequence(self, sample, sequence_status):
image_tensor_list = sample["image_tensor_list"]
text_ids_list = sample["text_ids_list"]
sequence_plan = sample["sequence_plan"]
split_lens, attn_modes = list(), list()
curr = sequence_status["curr"]
curr_rope_id = 0
sample_lens = 0
for item in sequence_plan:
split_start = item.get("split_start", True)
if split_start:
curr_split_len = 0
if item["type"] == "text":
text_ids = text_ids_list.pop(0)
if (
item["enable_cfg"] == 1
and random.random() < self.data_config.text_cond_dropout_prob
):
continue
shifted_text_ids = [self.bos_token_id] + text_ids
sequence_status["packed_text_ids"].extend(shifted_text_ids)
sequence_status["packed_text_indexes"].extend(
range(curr, curr + len(shifted_text_ids))
)
if item["loss"] == 1:
sequence_status["ce_loss_indexes"].extend(
range(curr, curr + len(shifted_text_ids))
)
sequence_status["ce_loss_weights"].extend(
[len2weight(len(shifted_text_ids))] * len(shifted_text_ids)
)
sequence_status["packed_label_ids"].extend(
text_ids + [self.eos_token_id]
)
curr += len(shifted_text_ids)
curr_split_len += len(shifted_text_ids)
# add a <|im_end|> token
sequence_status["packed_text_ids"].append(self.eos_token_id)
sequence_status["packed_text_indexes"].append(curr)
if item["special_token_loss"] == 1: # <|im_end|> may have loss
sequence_status["ce_loss_indexes"].append(curr)
sequence_status["ce_loss_weights"].append(1.0)
sequence_status["packed_label_ids"].append(
item["special_token_label"]
)
curr += 1
curr_split_len += 1
# update sequence status
attn_modes.append("causal")
sequence_status["packed_position_ids"].extend(
range(curr_rope_id, curr_rope_id + curr_split_len)
)
curr_rope_id += curr_split_len
elif item["type"] == "vit_image":
image_tensor = image_tensor_list.pop(0)
if (
item["enable_cfg"] == 1
and random.random() < self.data_config.vit_cond_dropout_prob
):
curr_rope_id += 1
continue
# add a <|startofimage|> token
sequence_status["packed_text_ids"].append(self.start_of_image)
sequence_status["packed_text_indexes"].append(curr)
curr += 1
curr_split_len += 1
# preprocess image
vit_tokens = patchify(image_tensor, self.data_config.vit_patch_size)
num_img_tokens = vit_tokens.shape[0]
sequence_status["packed_vit_token_indexes"].extend(
range(curr, curr + num_img_tokens)
)
curr += num_img_tokens
curr_split_len += num_img_tokens
sequence_status["packed_vit_tokens"].append(vit_tokens)
sequence_status["vit_token_seqlens"].append(num_img_tokens)
sequence_status["packed_vit_position_ids"].append(
self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.data_config.vit_patch_size,
max_num_patches_per_side=self.data_config.max_num_patch_per_side,
)
)
# add a <|endofimage|> token
sequence_status["packed_text_ids"].append(self.end_of_image)
sequence_status["packed_text_indexes"].append(curr)
if item["special_token_loss"] == 1: # <|endofimage|> may have loss
sequence_status["ce_loss_indexes"].append(curr)
sequence_status["ce_loss_weights"].append(1.0)
sequence_status["packed_label_ids"].append(
item["special_token_label"]
)
curr += 1
curr_split_len += 1
# update sequence status
attn_modes.append("full")
sequence_status["packed_position_ids"].extend(
[curr_rope_id] * curr_split_len
)
curr_rope_id += 1
elif item["type"] == "vae_image":
image_tensor = image_tensor_list.pop(0)
if (
item["enable_cfg"] == 1
and random.random() < self.data_config.vae_cond_dropout_prob
):
# FIXME fix vae dropout in video2video setting.
curr_rope_id += 1
continue
# add a <|startofimage|> token
sequence_status["packed_text_ids"].append(self.start_of_image)
sequence_status["packed_text_indexes"].append(curr)
curr += 1
curr_split_len += 1
# preprocess image
sequence_status["vae_image_tensors"].append(image_tensor)
sequence_status["packed_latent_position_ids"].append(
self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.data_config.vae_image_downsample,
max_num_patches_per_side=self.data_config.max_latent_size,
)
)
H, W = image_tensor.shape[1:]
h = H // self.data_config.vae_image_downsample
w = W // self.data_config.vae_image_downsample
sequence_status["vae_latent_shapes"].append((h, w))
num_img_tokens = w * h
sequence_status["packed_vae_token_indexes"].extend(
range(curr, curr + num_img_tokens)
)
if item["loss"] == 1:
sequence_status["mse_loss_indexes"].extend(
range(curr, curr + num_img_tokens)
)
if split_start:
timestep = np.random.randn()
else:
timestep = float("-inf")
sequence_status["packed_timesteps"].extend([timestep] * num_img_tokens)
curr += num_img_tokens
curr_split_len += num_img_tokens
# add a <|endofimage|> token
sequence_status["packed_text_ids"].append(self.end_of_image)
sequence_status["packed_text_indexes"].append(curr)
# <|endofimage|> may have loss
if item["special_token_loss"] == 1:
sequence_status["ce_loss_indexes"].append(curr)
sequence_status["ce_loss_weights"].append(1.0)
sequence_status["packed_label_ids"].append(
item["special_token_label"]
)
curr += 1
curr_split_len += 1
# update sequence status
if split_start:
if item["loss"] == 1 and "frame_delta" not in item.keys():
attn_modes.append("noise")
else:
attn_modes.append("full")
sequence_status["packed_position_ids"].extend(
[curr_rope_id] * (num_img_tokens + 2)
)
if "frame_delta" in item.keys():
curr_rope_id += item["frame_delta"]
elif item["loss"] == 0:
curr_rope_id += 1
if item.get("split_end", True):
split_lens.append(curr_split_len)
sample_lens += curr_split_len
sequence_status["curr"] = curr
sequence_status["sample_lens"].append(sample_lens)
# prepare attention mask
if not self.use_flex:
sequence_status["nested_attention_masks"].append(
prepare_attention_mask_per_sample(split_lens, attn_modes)
)
else:
sequence_status["split_lens"].extend(split_lens)
sequence_status["attn_modes"].extend(attn_modes)
return sequence_status
class SimpleCustomBatch:
def __init__(self, batch):
data = batch[0]
self.batch_data_indexes = data["batch_data_indexes"]
self.sequence_length = data["sequence_length"]
self.sample_lens = data["sample_lens"]
self.packed_text_ids = data["packed_text_ids"]
self.packed_text_indexes = data["packed_text_indexes"]
self.packed_position_ids = data["packed_position_ids"]
self.use_flex = "nested_attention_masks" not in data.keys()
if self.use_flex:
self.split_lens = data["split_lens"]
self.attn_modes = data["attn_modes"]
else:
self.nested_attention_masks = data["nested_attention_masks"]
if "padded_images" in data.keys():
self.padded_images = data["padded_images"]
self.patchified_vae_latent_shapes = data["patchified_vae_latent_shapes"]
self.packed_latent_position_ids = data["packed_latent_position_ids"]
self.packed_vae_token_indexes = data["packed_vae_token_indexes"]
if "packed_vit_tokens" in data.keys():
self.packed_vit_tokens = data["packed_vit_tokens"]
self.packed_vit_position_ids = data["packed_vit_position_ids"]
self.packed_vit_token_indexes = data["packed_vit_token_indexes"]
self.vit_token_seqlens = data["vit_token_seqlens"]
if "packed_timesteps" in data.keys():
self.packed_timesteps = data["packed_timesteps"]
self.mse_loss_indexes = data["mse_loss_indexes"]
if "packed_label_ids" in data.keys():
self.packed_label_ids = data["packed_label_ids"]
self.ce_loss_indexes = data["ce_loss_indexes"]
self.ce_loss_weights = data["ce_loss_weights"]
def pin_memory(self):
self.packed_text_ids = self.packed_text_ids.pin_memory()
self.packed_text_indexes = self.packed_text_indexes.pin_memory()
self.packed_position_ids = self.packed_position_ids.pin_memory()
if not self.use_flex:
self.nested_attention_masks = [
item.pin_memory() for item in self.nested_attention_masks
]
if hasattr(self, "padded_images"):
self.padded_images = self.padded_images.pin_memory()
self.packed_vae_token_indexes = self.packed_vae_token_indexes.pin_memory()
self.packed_latent_position_ids = (
self.packed_latent_position_ids.pin_memory()
)
if hasattr(self, "packed_timesteps"):
self.packed_timesteps = self.packed_timesteps.pin_memory()
self.mse_loss_indexes = self.mse_loss_indexes.pin_memory()
if hasattr(self, "packed_vit_tokens"):
self.packed_vit_tokens = self.packed_vit_tokens.pin_memory()
self.packed_vit_position_ids = self.packed_vit_position_ids.pin_memory()
self.packed_vit_token_indexes = self.packed_vit_token_indexes.pin_memory()
self.vit_token_seqlens = self.vit_token_seqlens.pin_memory()
if hasattr(self, "packed_label_ids"):
self.packed_label_ids = self.packed_label_ids.pin_memory()
self.ce_loss_indexes = self.ce_loss_indexes.pin_memory()
self.ce_loss_weights = self.ce_loss_weights.pin_memory()
return self
def cuda(self, device):
self.packed_text_ids = self.packed_text_ids.to(device)
self.packed_text_indexes = self.packed_text_indexes.to(device)
self.packed_position_ids = self.packed_position_ids.to(device)
if not self.use_flex:
self.nested_attention_masks = [
item.to(device) for item in self.nested_attention_masks
]
if hasattr(self, "padded_images"):
self.padded_images = self.padded_images.to(device)
self.packed_vae_token_indexes = self.packed_vae_token_indexes.to(device)
self.packed_latent_position_ids = self.packed_latent_position_ids.to(device)
if hasattr(self, "packed_timesteps"):
self.packed_timesteps = self.packed_timesteps.to(device)
self.mse_loss_indexes = self.mse_loss_indexes.to(device)
if hasattr(self, "packed_vit_tokens"):
self.packed_vit_tokens = self.packed_vit_tokens.to(device)
self.packed_vit_position_ids = self.packed_vit_position_ids.to(device)
self.packed_vit_token_indexes = self.packed_vit_token_indexes.to(device)
self.vit_token_seqlens = self.vit_token_seqlens.to(device)
if hasattr(self, "packed_label_ids"):
self.packed_label_ids = self.packed_label_ids.to(device)
self.ce_loss_indexes = self.ce_loss_indexes.to(device)
self.ce_loss_weights = self.ce_loss_weights.to(device)
return self
def to_dict(self):
data = dict(
sequence_length=self.sequence_length,
sample_lens=self.sample_lens,
packed_text_ids=self.packed_text_ids,
packed_text_indexes=self.packed_text_indexes,
packed_position_ids=self.packed_position_ids,
batch_data_indexes=self.batch_data_indexes,
)
if not self.use_flex:
data["nested_attention_masks"] = self.nested_attention_masks
else:
data["split_lens"] = self.split_lens
data["attn_modes"] = self.attn_modes
if hasattr(self, "padded_images"):
data["padded_images"] = self.padded_images
data["patchified_vae_latent_shapes"] = self.patchified_vae_latent_shapes
data["packed_latent_position_ids"] = self.packed_latent_position_ids
data["packed_vae_token_indexes"] = self.packed_vae_token_indexes
if hasattr(self, "packed_vit_tokens"):
data["packed_vit_tokens"] = self.packed_vit_tokens
data["packed_vit_position_ids"] = self.packed_vit_position_ids
data["packed_vit_token_indexes"] = self.packed_vit_token_indexes
data["vit_token_seqlens"] = self.vit_token_seqlens
if hasattr(self, "packed_timesteps"):
data["packed_timesteps"] = self.packed_timesteps
data["mse_loss_indexes"] = self.mse_loss_indexes
if hasattr(self, "packed_label_ids"):
data["packed_label_ids"] = self.packed_label_ids
data["ce_loss_indexes"] = self.ce_loss_indexes
data["ce_loss_weights"] = self.ce_loss_weights
return data
def collate_wrapper():
def collate_fn(batch):
return SimpleCustomBatch(batch)
return collate_fn
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import glob
import json
import os
import os.path as osp
from .edit_dataset_jsonl import EditJSONLIterableDataset
from .interleave_datasets import UnifiedEditIterableDataset
from .t2i_dataset import T2IIterableDataset
from .t2i_dataset_jsonl import T2IJSONLIterableDataset
from .vlm_dataset import SftJSONLIterableDataset
DATASET_REGISTRY = {
"sensenova_si_800K": SftJSONLIterableDataset,
"sensenova_si_8M": SftJSONLIterableDataset,
}
DATASET_INFO = {}
# load additional dataset info from the dataset_info/ directory
dataset_info_path = osp.join(osp.dirname(__file__), "dataset_info")
dataset_info_files = glob.glob(osp.join(dataset_info_path, "*.json"))
training_root = os.environ.get(
"TRAINING_ROOT",
osp.abspath(osp.join(osp.dirname(__file__), "..", "..", "..")),
)
def _resolve_training_root_path(value):
if isinstance(value, str):
return value.replace("__TRAINING_ROOT__", training_root)
if isinstance(value, list):
return [_resolve_training_root_path(v) for v in value]
if isinstance(value, dict):
return {k: _resolve_training_root_path(v) for k, v in value.items()}
return value
for dataset_info_file in dataset_info_files:
with open(dataset_info_file, "r") as f:
base_name = osp.splitext(osp.basename(dataset_info_file))[0]
dataset_info = _resolve_training_root_path(json.load(f))
for key in dataset_info.keys():
if key in DATASET_INFO:
raise ValueError(f"Key {key} already exists in DATASET_INFO")
DATASET_INFO.update({base_name: dataset_info})
{
"sensenova_si_800K": {
"data_dir": "__TRAINING_ROOT__/data/SenseNova-SI-800K/",
"jsonl_path": "__TRAINING_ROOT__/data/SenseNova-SI-800K/SenseNova-SI-800K.jsonl",
"num_total_samples": 832132
}
}
{
"sensenova_si_8M": {
"data_dir": "__TRAINING_ROOT__/data/SenseNova-SI-8M/",
"jsonl_path": "__TRAINING_ROOT__/data/SenseNova-SI-8M/SenseNova-SI-8M.jsonl",
"num_total_samples": 8165067
}
}
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import random
import torch
class DistributedIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, dataset_name, local_rank=0, world_size=1, num_workers=8):
self.dataset_name = dataset_name
self.local_rank = local_rank
self.world_size = world_size
self.num_workers = num_workers
self.rng = random.Random()
self.data_paths = None
def get_data_paths(self, *args, **kwargs):
raise NotImplementedError
def set_epoch(self, seed=42):
if self.data_paths is None:
return
if isinstance(self.data_paths[0], tuple):
data_paths = sorted(self.data_paths, key=lambda x: (x[0], x[1]))
elif isinstance(self.data_paths[0], str):
data_paths = sorted(self.data_paths)
else:
raise ValueError(f"Unknown data_paths type: {type(self.data_paths[0])}")
self.rng.seed(seed)
self.rng.shuffle(data_paths)
num_files_per_rank = len(data_paths) // self.world_size
local_start = self.local_rank * num_files_per_rank
local_end = (self.local_rank + 1) * num_files_per_rank
self.num_files_per_rank = num_files_per_rank
self.data_paths_per_rank = data_paths[local_start:local_end]
def get_data_paths_per_worker(self):
if self.data_paths is None:
return None
info = torch.utils.data.get_worker_info()
if info is None:
# Single worker: Use all files assigned to the rank
return self.data_paths_per_rank, 0
worker_id = info.id
num_files_per_worker = self.num_files_per_rank // info.num_workers
start = num_files_per_worker * worker_id
end = num_files_per_worker * (worker_id + 1)
data_paths_per_worker = self.data_paths_per_rank[start:end]
return data_paths_per_worker[::-1], worker_id
def __iter__(self):
raise NotImplementedError
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import json
import os
import random
import pyarrow.parquet as pq
from PIL import Image, ImageFile, PngImagePlugin
from .data_utils import load_image, pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset
from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
Image.MAX_IMAGE_PIXELS = 200000000
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2**20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
class EditJSONLIterableDataset(DistributedIterableDataset):
def _add_text(self, sample, text, need_loss, enable_cfg=True):
text_ids = self.tokenizer.encode(text)
sample["num_tokens"] += len(text_ids)
sample["text_ids_list"].append(text_ids)
sample["sequence_plan"].append(
{
"type": "text",
"enable_cfg": int(enable_cfg),
"loss": int(need_loss),
"special_token_loss": 0,
"special_token_label": None,
}
)
return sample
def _resize_and_pad(self, img: Image.Image, is_mask=False) -> Image.Image:
"""根据 __init__ 里解析好的 fixed_size 进行 resize/pad"""
if self.fixed_size == None:
return img
interp = Image.NEAREST if is_mask else Image.BICUBIC
# case1: (H,W) 矩形 resize
# if self.fixed_mode == "rect":
target_h, target_w = self.fixed_size, self.fixed_size
return img.resize((target_w, target_h), interp)
def _add_image(self, sample, image, need_loss, need_vae, need_vit, enable_cfg=True):
assert need_loss or need_vae or need_vit
if need_loss:
sample["sequence_plan"].append(
{
"type": "vae_image",
"enable_cfg": 0,
"loss": 1,
"special_token_loss": 0,
"special_token_label": None,
}
)
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
sample["num_tokens"] += width * height // self.transform.stride**2
sample["image_tensor_list"].append(image_tensor)
if need_vae:
sample["sequence_plan"].append(
{
"type": "vae_image",
"enable_cfg": int(enable_cfg),
"loss": 0,
"special_token_loss": 0,
"special_token_label": None,
}
)
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
sample["num_tokens"] += width * height // self.transform.stride**2
sample["image_tensor_list"].append(image_tensor.clone())
if need_vit:
sample["sequence_plan"].append(
{
"type": "vit_image",
"enable_cfg": int(enable_cfg),
"loss": 0,
"special_token_loss": 0,
"special_token_label": None,
},
)
vit_image_tensor = self.vit_transform(image)
height, width = vit_image_tensor.shape[1:]
sample["num_tokens"] += width * height // self.vit_transform.stride**2
sample["image_tensor_list"].append(vit_image_tensor)
return sample
def __init__(
self,
dataset_name,
transform,
tokenizer,
vit_transform,
jsonl_path_list,
data_dir_list,
num_used_data,
local_rank=0,
world_size=1,
num_workers=8,
data_status=None,
shuffle_lines=False,
shuffle_seed=0,
fixed_size=None,
):
"""
jsonl_path_list: list of jsonl file paths
data_dir_list: list of image directories containing the images of each jsonl file
num_used_data: list of number of sampled data points for each jsonl
"""
super().__init__(dataset_name, local_rank, world_size, num_workers)
self.transform = transform
if fixed_size is None:
self.fixed_size = None
else:
self.fixed_size = fixed_size
self.tokenizer = tokenizer
self.vit_transform = vit_transform
self.data_status = data_status
self.data_paths = self.get_data_paths(
jsonl_path_list,
data_dir_list,
num_used_data,
shuffle_lines,
shuffle_seed,
)
self.set_epoch()
def get_data_paths(
self,
jsonl_path_list,
data_dir_list,
num_used_data,
shuffle_lines,
shuffle_seed,
):
data_paths = []
for jsonl_path, image_dir, num_data_point in zip(
jsonl_path_list, data_dir_list, num_used_data
):
with open(jsonl_path, "r") as f:
raw_data = f.readlines()
if shuffle_lines:
self.rng.seed(shuffle_seed)
self.rng.shuffle(raw_data)
raw_data = raw_data[:num_data_point]
data_paths.extend([(json_data, image_dir) for json_data in raw_data])
return data_paths
def __iter__(self):
data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
if self.data_status is not None:
row_start_id = self.data_status[worker_id] + 1
else:
row_start_id = 0
transform_stride = self.transform.stride
print(
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
f"resuming data at row#{row_start_id}"
)
while True:
data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
for row_idx, (data, image_dir) in enumerate(
data_paths_per_worker_, start=row_start_id
):
sample = {
"sequence_plan": [],
"text_ids_list": [],
"image_tensor_list": [],
"num_tokens": 0,
}
# try:
data_item = json.loads(data)
sample = self._add_image(
sample,
# pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'][0]))),
pil_img2rgb(
self._resize_and_pad(
load_image(os.path.join(image_dir, data_item["image"][0]))
)
),
need_loss=False,
need_vae=True,
need_vit=True,
)
if "instruction" in data_item:
instruction = data_item["instruction"]
elif "conversations" in data_item:
conversations = data_item["conversations"]
if len(conversations) == 2:
if conversations[0]["from"] == "human":
instruction = conversations[0]["value"].replace(
"<image>", ""
)
# instruction = data_item['conversation']
else:
print("no caption in " + data_item)
sample = self._add_text(sample, instruction.rstrip(), need_loss=False)
sample = self._add_image(
sample,
# pil_img2rgb(Image.open(os.path.join(image_dir, data_item['image'][1]))),
pil_img2rgb(
self._resize_and_pad(
load_image(os.path.join(image_dir, data_item["image"][1]))
)
),
need_loss=True,
need_vae=False,
need_vit=False,
)
# except:
# print(f"Error in row {row_idx}")
# continue
sample["data_indexes"] = {
"data_indexes": row_idx,
"worker_id": worker_id,
"dataset_name": self.dataset_name,
}
# print('image[0]: ',sample['image_tensor_list'][0].shape)
# print('image[1]: ',sample['image_tensor_list'][1].shape)
yield sample
row_start_id = 0
print(
f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}"
)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import subprocess
import pyarrow.fs as pf
import torch.distributed as dist
logger = logging.getLogger(__name__)
def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
num_data_dirs = len(data_dir_list)
if world_size > 1:
chunk_size = (num_data_dirs + world_size - 1) // world_size
start_idx = rank * chunk_size
end_idx = min(start_idx + chunk_size, num_data_dirs)
local_data_dir_list = data_dir_list[start_idx:end_idx]
local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
else:
local_data_dir_list = data_dir_list
local_num_sampled_data_paths = num_sampled_data_paths
local_data_paths = []
for data_dir, num_data_path in zip(
local_data_dir_list, local_num_sampled_data_paths
):
if data_dir.startswith("hdfs://"):
files = hdfs_ls_cmd(data_dir)
data_paths_per_dir = [file for file in files if file.endswith(".parquet")]
else:
files = os.listdir(data_dir)
data_paths_per_dir = [
os.path.join(data_dir, name)
for name in files
if name.endswith(".parquet")
]
repeat = num_data_path // len(data_paths_per_dir)
data_paths_per_dir = data_paths_per_dir * (repeat + 1)
local_data_paths.extend(data_paths_per_dir[:num_data_path])
if world_size > 1:
gather_list = [None] * world_size
dist.all_gather_object(gather_list, local_data_paths)
combined_chunks = []
for chunk_list in gather_list:
if chunk_list is not None:
combined_chunks.extend(chunk_list)
else:
combined_chunks = local_data_paths
return combined_chunks
# NOTE: cumtomize this function for your cluster
def get_hdfs_host():
return "hdfs://xxx"
# NOTE: cumtomize this function for your cluster
def get_hdfs_block_size():
return 134217728
# NOTE: cumtomize this function for your cluster
def get_hdfs_extra_conf():
return None
def init_arrow_pf_fs(parquet_file_path):
if parquet_file_path.startswith("hdfs://"):
fs = pf.HadoopFileSystem(
host=get_hdfs_host(),
port=0,
buffer_size=get_hdfs_block_size(),
extra_conf=get_hdfs_extra_conf(),
)
else:
fs = pf.LocalFileSystem()
return fs
def hdfs_ls_cmd(dir):
result = subprocess.run(
["hdfs", "dfs", "ls", dir], capture_output=True, text=True
).stdout
return [
"hdfs://" + i.split("hdfs://")[-1].strip()
for i in result.split("\n")
if "hdfs://" in i
]
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import json
import random
import pyarrow.parquet as pq
from PIL import Image
from .data_utils import pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset
from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
Image.MAX_IMAGE_PIXELS = 20_000_000
class T2IIterableDataset(DistributedIterableDataset):
def __init__(
self,
dataset_name,
transform,
tokenizer,
data_dir_list,
num_used_data,
local_rank=0,
world_size=1,
num_workers=8,
data_status=None,
):
"""
data_dir_list: list of data directories contains parquet files
num_used_data: list of number of sampled data paths for each data directory
"""
super().__init__(dataset_name, local_rank, world_size, num_workers)
self.transform = transform
self.tokenizer = tokenizer
self.data_status = data_status
self.data_paths = self.get_data_paths(data_dir_list, num_used_data)
self.set_epoch()
def get_data_paths(self, data_dir_list, num_used_data):
return get_parquet_data_paths(data_dir_list, num_used_data)
def __iter__(self):
data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
if self.data_status is not None:
parquet_start_id = self.data_status[worker_id][0]
row_group_start_id = self.data_status[worker_id][1]
row_start_id = self.data_status[worker_id][2] + 1
else:
parquet_start_id = 0
row_group_start_id = 0
row_start_id = 0
transform_stride = self.transform.stride
print(
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
f"resuming data at parquet#{parquet_start_id}, rg#{row_group_start_id}, row#{row_start_id}"
)
while True:
data_paths_per_worker_ = data_paths_per_worker[parquet_start_id:]
for parquet_idx, parquet_file_path in enumerate(
data_paths_per_worker_, start=parquet_start_id
):
fs = init_arrow_pf_fs(parquet_file_path)
with fs.open_input_file(parquet_file_path) as f:
fr = pq.ParquetFile(f)
row_group_ids = list(range(fr.num_row_groups))
row_group_ids_ = row_group_ids[row_group_start_id:]
for row_group_id in row_group_ids_:
df = fr.read_row_group(row_group_id).to_pandas()
df = df.iloc[row_start_id:]
for row_idx, row in df.iterrows():
num_tokens = 0
try:
image_byte = row["image"]
image = pil_img2rgb(Image.open(io.BytesIO(image_byte)))
except Exception as e:
print(
f"Error: {e} in rg#{row_group_id}, {parquet_file_path}"
)
continue
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
num_tokens += width * height // transform_stride**2
try:
caption_dict = row["captions"]
caption_dict = json.loads(caption_dict)
except Exception as e:
print(
f"Error: {e} in rg#{row_group_id}, {parquet_file_path}"
)
continue
caps_token = [
self.tokenizer.encode(v)
for _, v in caption_dict.items()
]
if len(caps_token) == 0:
print(
f"no caption in rg#{row_group_id}, {parquet_file_path}"
)
caption_token = self.tokenizer.encode(" ")
else:
caption_token = random.choice(caps_token)
sequence_plan, text_ids_list = [], []
text_ids = caption_token
num_tokens += len(caption_token)
text_ids_list.append(text_ids)
sequence_plan.append(
{
"type": "text",
"enable_cfg": 1,
"loss": 0,
"special_token_loss": 0,
"special_token_label": None,
}
)
sequence_plan.append(
{
"type": "vae_image",
"enable_cfg": 0,
"loss": 1,
"special_token_loss": 0,
"special_token_label": None,
}
)
sample = dict(
image_tensor_list=[image_tensor],
text_ids_list=text_ids_list,
num_tokens=num_tokens,
sequence_plan=sequence_plan,
data_indexes={
"data_indexes": [
parquet_idx,
row_group_id,
row_idx,
],
"worker_id": worker_id,
"dataset_name": self.dataset_name,
},
)
yield sample
row_start_id = 0
row_group_start_id = 0
parquet_start_id = 0
print(
f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}"
)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import io
import json
import os
import random
import traceback
import pyarrow.parquet as pq
from PIL import Image
from .data_utils import load_image, pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset
from .parquet_utils import get_parquet_data_paths, init_arrow_pf_fs
Image.MAX_IMAGE_PIXELS = 200_000_000
class T2IJSONLIterableDataset(DistributedIterableDataset):
def __init__(
self,
dataset_name,
transform,
tokenizer,
jsonl_path_list,
data_dir_list,
num_used_data,
local_rank=0,
world_size=1,
num_workers=8,
data_status=None,
):
"""
data_dir_list: list of data directories contains parquet files
num_used_data: list of number of sampled data paths for each data directory
"""
super().__init__(dataset_name, local_rank, world_size, num_workers)
self.transform = transform
self.tokenizer = tokenizer
self.data_status = data_status
self.data_paths = self.get_data_paths(
jsonl_path_list, data_dir_list, num_used_data
)
self.set_epoch()
def get_data_paths(self, jsonl_path_list, data_dir_list, num_used_data):
data_paths = []
for jsonl_path, image_dir, num_data_point in zip(
jsonl_path_list, data_dir_list, num_used_data
):
with open(jsonl_path, "r") as f:
raw_data = f.readlines()
raw_data = raw_data[:num_data_point]
data_paths.extend([(json_data, image_dir) for json_data in raw_data])
return data_paths
def __iter__(self):
data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
if self.data_status is not None:
row_start_id = self.data_status[worker_id] + 1
else:
row_start_id = 0
transform_stride = self.transform.stride
print(
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
f"resuming data at row#{row_start_id}"
)
while True:
data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
for row_idx, (data, image_dir) in enumerate(
data_paths_per_worker_, start=row_start_id
):
num_tokens = 0
try:
data_item = json.loads(data)
image = None
if "image" in data_item:
image = pil_img2rgb(
load_image(os.path.join(image_dir, data_item["image"]))
)
except Exception as e:
# print(f'Error: {e} in rg#{row_group_id}, {parquet_file_path}')
print(f"Erroe image: {e} in {data} in {self.dataset_name}")
traceback.print_exc()
continue
image_tensor = self.transform(image)
height, width = image_tensor.shape[1:]
num_tokens += width * height // transform_stride**2
try:
if "conversations" in data_item:
caption_list = data_item["conversations"]
if caption_list[0]["from"] == "human":
caption_str = caption_list[0]["value"]
caption_dict = {"captions": caption_str}
# if 'captions' in row.keys():
# caption_dict = row['captions']
# caption_dict = json.loads(caption_dict)
# elif 'txt' in row.keys():
# caption_str = row['txt']
# caption_dict = {'captions':caption_str}
except Exception as e:
print(f"Error caption: {e} in {data} in {self.dataset_name}")
continue
caps_token = [self.tokenizer.encode(v) for _, v in caption_dict.items()]
if len(caps_token) == 0:
print(f"no caption in {data} in {self.dataset_name}")
caption_token = self.tokenizer.encode(" ")
else:
caption_token = random.choice(caps_token)
sequence_plan, text_ids_list = [], []
text_ids = caption_token
num_tokens += len(caption_token)
text_ids_list.append(text_ids)
sequence_plan.append(
{
"type": "text",
"enable_cfg": 1,
"loss": 0,
"special_token_loss": 0,
"special_token_label": None,
}
)
sequence_plan.append(
{
"type": "vae_image",
"enable_cfg": 0,
"loss": 1,
"special_token_loss": 0,
"special_token_label": None,
}
)
sample = dict(
image_tensor_list=[image_tensor],
text_ids_list=text_ids_list,
num_tokens=num_tokens,
sequence_plan=sequence_plan,
data_indexes={
"data_indexes": row_idx,
"worker_id": worker_id,
"dataset_name": self.dataset_name,
},
)
yield sample
row_start_id = 0
print(
f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}"
)
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import random
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
class MaxLongEdgeMinShortEdgeResize(torch.nn.Module):
"""Resize the input image so that its longest side and shortest side are within a specified range,
ensuring that both sides are divisible by a specified stride.
Args:
max_size (int): Maximum size for the longest edge of the image.
min_size (int): Minimum size for the shortest edge of the image.
stride (int): Value by which the height and width of the image must be divisible.
max_pixels (int): Maximum pixels for the full image.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``,
``InterpolationMode.BILINEAR``, and ``InterpolationMode.BICUBIC`` are supported.
The corresponding Pillow integer constants, e.g., ``PIL.Image.BILINEAR`` are also accepted.
antialias (bool, optional): Whether to apply antialiasing (default is True).
"""
def __init__(
self,
max_size: int,
min_size: int,
stride: int,
max_pixels: int,
interpolation=InterpolationMode.BICUBIC,
antialias=True,
):
super().__init__()
self.max_size = max_size
self.min_size = min_size
self.stride = stride
self.max_pixels = max_pixels
self.interpolation = interpolation
self.antialias = antialias
def _make_divisible(self, value, stride):
"""Ensure the value is divisible by the stride."""
return max(stride, int(round(value / stride) * stride))
def _apply_scale(self, width, height, scale):
new_width = round(width * scale)
new_height = round(height * scale)
new_width = self._make_divisible(new_width, self.stride)
new_height = self._make_divisible(new_height, self.stride)
return new_width, new_height
def forward(self, img, img_num=1):
"""
Args:
img (PIL Image): Image to be resized.
img_num (int): Number of images, used to change max_tokens.
Returns:
PIL Image or Tensor: Rescaled image with divisible dimensions.
"""
if isinstance(img, torch.Tensor):
height, width = img.shape[-2:]
else:
width, height = img.size
scale = min(self.max_size / max(width, height), 1.0)
scale = max(scale, self.min_size / min(width, height))
new_width, new_height = self._apply_scale(width, height, scale)
# Ensure the number of pixels does not exceed max_pixels
if new_width * new_height > self.max_pixels / img_num:
scale = self.max_pixels / img_num / (new_width * new_height)
new_width, new_height = self._apply_scale(new_width, new_height, scale)
# Ensure longest edge does not exceed max_size
if max(new_width, new_height) > self.max_size:
scale = self.max_size / max(new_width, new_height)
new_width, new_height = self._apply_scale(new_width, new_height, scale)
return F.resize(
img, (new_height, new_width), self.interpolation, antialias=self.antialias
)
class ImageTransform:
def __init__(
self,
max_image_size,
min_image_size,
image_stride,
max_pixels=14 * 14 * 9 * 1024,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
self.stride = image_stride
self.resize_transform = MaxLongEdgeMinShortEdgeResize(
max_size=max_image_size,
min_size=min_image_size,
stride=image_stride,
max_pixels=max_pixels,
)
self.to_tensor_transform = transforms.ToTensor()
self.normalize_transform = transforms.Normalize(
mean=image_mean, std=image_std, inplace=True
)
def __call__(self, img, img_num=1):
img = self.resize_transform(img, img_num=img_num)
img = self.to_tensor_transform(img)
img = self.normalize_transform(img)
return img
def decolorization(image):
gray_image = image.convert("L")
return (
Image.merge(image.mode, [gray_image] * 3)
if image.mode in ("RGB", "L")
else gray_image
)
def downscale(image, scale_factor):
new_width = int(round(image.width * scale_factor))
new_height = int(round(image.height * scale_factor))
new_width = max(1, new_width)
new_height = max(1, new_height)
return image.resize((new_width, new_height), resample=Image.BICUBIC)
def crop(image, crop_factors):
target_h, target_w = crop_factors
img_w, img_h = image.size
if target_h > img_h or target_w > img_w:
raise ValueError("Crop size exceeds image dimensions")
x = random.randint(0, img_w - target_w)
y = random.randint(0, img_h - target_h)
return image.crop((x, y, x + target_w, y + target_h)), [
[x, y],
[x + target_w, y + target_h],
]
def motion_blur_opencv(image, kernel_size=15, angle=0):
# 线性核
kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)
kernel[kernel_size // 2, :] = np.ones(kernel_size, dtype=np.float32)
# 旋转核
center = (kernel_size / 2 - 0.5, kernel_size / 2 - 0.5)
M = cv2.getRotationMatrix2D(center, angle, 1)
rotated_kernel = cv2.warpAffine(kernel, M, (kernel_size, kernel_size))
# 归一化核
rotated_kernel /= rotated_kernel.sum() if rotated_kernel.sum() != 0 else 1
img = np.array(image)
if img.ndim == 2:
blurred = cv2.filter2D(img, -1, rotated_kernel, borderType=cv2.BORDER_REFLECT)
else:
# 对于彩色图像,各通道独立卷积
blurred = np.zeros_like(img)
for c in range(img.shape[2]):
blurred[..., c] = cv2.filter2D(
img[..., c], -1, rotated_kernel, borderType=cv2.BORDER_REFLECT
)
return Image.fromarray(blurred.astype(np.uint8))
def shuffle_patch(image, num_splits, gap_size=2):
"""将图像分割为块(允许尺寸不整除),随机打乱后拼接,块间保留间隙"""
h_splits, w_splits = num_splits
img_w, img_h = image.size
base_patch_h = img_h // h_splits
patch_heights = [base_patch_h] * (h_splits - 1)
patch_heights.append(img_h - sum(patch_heights))
base_patch_w = img_w // w_splits
patch_widths = [base_patch_w] * (w_splits - 1)
patch_widths.append(img_w - sum(patch_widths))
patches = []
current_y = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
patch_w = patch_widths[j]
patch = image.crop(
(current_x, current_y, current_x + patch_w, current_y + patch_h)
)
patches.append(patch)
current_x += patch_w
current_y += patch_h
random.shuffle(patches)
total_width = sum(patch_widths) + (w_splits - 1) * gap_size
total_height = sum(patch_heights) + (h_splits - 1) * gap_size
new_image = Image.new(
image.mode, (total_width, total_height), color=(255, 255, 255)
)
current_y = 0 # 当前行的起始 Y 坐标
patch_idx = 0 # 当前处理的块索引
for i in range(h_splits):
current_x = 0 # 当前列的起始 X 坐标
patch_h = patch_heights[i] # 当前行块的高度
for j in range(w_splits):
# 取出打乱后的块
patch = patches[patch_idx]
patch_w = patch_widths[j] # 当前列块的宽度
# 粘贴块(左上角坐标为 (current_x, current_y))
new_image.paste(patch, (current_x, current_y))
# 更新 X 坐标(下一个块的起始位置 = 当前块宽度 + 间隙)
current_x += patch_w + gap_size
patch_idx += 1
# 更新 Y 坐标(下一行的起始位置 = 当前行高度 + 间隙)
current_y += patch_h + gap_size
return new_image
def inpainting(image, num_splits, blank_ratio=0.3, blank_color=(255, 255, 255)):
"""
图像分割后随机空白部分patch,用于inpainting任务
参数:
image: PIL.Image 输入图像(RGB模式)
h_splits: int 行分割数(垂直方向分割块数)
w_splits: int 列分割数(水平方向分割块数)
blank_ratio: float 空白patch的比例(0~1)
blank_color: tuple 空白区域的颜色(RGB,如白色(255,255,255))
返回:
PIL.Image 处理后拼接的图像
"""
h_splits, w_splits = num_splits
img_w, img_h = image.size
base_patch_h = img_h // h_splits
patch_heights = [base_patch_h] * (h_splits - 1)
patch_heights.append(img_h - sum(patch_heights))
base_patch_w = img_w // w_splits
patch_widths = [base_patch_w] * (w_splits - 1)
patch_widths.append(img_w - sum(patch_widths))
patches = []
current_y = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
patch_w = patch_widths[j]
patch = image.crop(
(current_x, current_y, current_x + patch_w, current_y + patch_h)
)
patches.append(patch)
current_x += patch_w
current_y += patch_h
total_patches = h_splits * w_splits
num_blank = int(total_patches * blank_ratio)
num_blank = max(0, min(num_blank, total_patches))
blank_indices = random.sample(range(total_patches), num_blank)
processed_patches = []
for idx, patch in enumerate(patches):
if idx in blank_indices:
blank_patch = Image.new("RGB", patch.size, color=blank_color)
processed_patches.append(blank_patch)
else:
processed_patches.append(patch)
# 创建结果图像(尺寸与原图一致)
result_image = Image.new("RGB", (img_w, img_h))
current_y = 0
patch_idx = 0
for i in range(h_splits):
current_x = 0
patch_h = patch_heights[i]
for j in range(w_splits):
# 取出处理后的patch
patch = processed_patches[patch_idx]
patch_w = patch_widths[j]
# 粘贴到原位置
result_image.paste(patch, (current_x, current_y))
current_x += patch_w
patch_idx += 1
current_y += patch_h
return result_image
# Copyright (c) 2023 OpenGVLab
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: MIT
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under MIT, with the full license text
# available at https://github.com/OpenGVLab/InternVL/blob/main/LICENSE.
#
# This modified file is released under the same license.
import io
import os
import random
import re
import decord
import numpy as np
from PIL import Image
def get_frame_indices(
num_frames, vlen, sample="rand", fix_start=None, input_fps=1, max_num_frames=-1
):
if sample in ["rand", "middle"]: # uniform sampling
acc_samples = min(num_frames, vlen)
# split the video into `acc_samples` intervals, and sample from each interval.
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
ranges = []
for idx, interv in enumerate(intervals[:-1]):
ranges.append((interv, intervals[idx + 1] - 1))
if sample == "rand":
try:
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
except:
frame_indices = np.random.permutation(vlen)[:acc_samples]
frame_indices.sort()
frame_indices = list(frame_indices)
elif fix_start is not None:
frame_indices = [x[0] + fix_start for x in ranges]
elif sample == "middle":
frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
else:
raise NotImplementedError
if len(frame_indices) < num_frames: # padded with last frame
padded_frame_indices = [frame_indices[-1]] * num_frames
padded_frame_indices[: len(frame_indices)] = frame_indices
frame_indices = padded_frame_indices
elif "fps" in sample: # fps0.5, sequentially sample frames at 0.5 fps
output_fps = float(sample[3:])
duration = float(vlen) / input_fps
delta = (
1 / output_fps
) # gap between frames, this is also the clip length each frame represents
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
frame_indices = np.around(frame_seconds * input_fps).astype(int)
frame_indices = [e for e in frame_indices if e < vlen]
if max_num_frames > 0 and len(frame_indices) > max_num_frames:
frame_indices = frame_indices[:max_num_frames]
else:
raise ValueError
return frame_indices
def read_frames_decord(
video_path, num_frames, sample="rand", fix_start=None, clip=None, min_num_frames=4
):
video_reader = decord.VideoReader(video_path, num_threads=1)
vlen = len(video_reader)
fps = video_reader.get_avg_fps()
duration = vlen / float(fps)
if clip:
start, end = clip
duration = end - start
vlen = int(duration * fps)
start_index = int(start * fps)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start, input_fps=fps
)
if clip:
frame_indices = [f + start_index for f in frame_indices]
frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
return frames
def extract_frame_number(filename):
# Extract the numeric part from the filename using regular expressions
match = re.search(r"_(\d+).jpg$", filename)
return int(match.group(1)) if match else -1
def sort_frames(frame_paths):
# Extract filenames from each path and sort by their numeric part
return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))
def read_frames_folder(
video_path, num_frames, sample="rand", fix_start=None, min_num_frames=4
):
image_list = sort_frames(list(os.listdir(video_path)))
frames = []
for image in image_list:
fp = os.path.join(video_path, image)
frame = Image.open(fp).convert("RGB")
frames.append(frame)
vlen = len(frames)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
if vlen > t_num_frames:
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start
)
frames = [frames[i] for i in frame_indices]
return frames
class FrameSampler:
def __init__(self, max_num_frames=-1, min_num_frames=8, sample="rand"):
self.max_num_frames = max_num_frames
self.min_num_frames = min_num_frames
self.sample = sample
def __call__(self, file_name):
fn = read_frames_folder if file_name.endswith("/") else read_frames_decord
frames = fn(
file_name,
num_frames=self.max_num_frames,
min_num_frames=self.min_num_frames,
sample=self.sample,
)
return frames
def decode_video_byte(video_bytes):
video_stream = io.BytesIO(video_bytes)
vr = decord.VideoReader(video_stream)
return vr
def sample_mp4_frames(
mp4_p, n_frames=None, fps=None, return_frame_indices=False, random_sample=False
):
if isinstance(mp4_p, str):
vr = decord.VideoReader(mp4_p, num_threads=1)
elif isinstance(mp4_p, decord.video_reader.VideoReader):
vr = mp4_p
video_fps = vr.get_avg_fps() # 获取视频的帧率
video_duration = len(vr) / video_fps
if n_frames is not None:
if random_sample:
frame_indices = sorted(random.sample(range(len(vr)), n_frames))
else:
frame_indices = np.linspace(0, len(vr) - 1, n_frames, dtype=int).tolist()
else:
frame_indices = [int(i) for i in np.arange(0, len(vr) - 1, video_fps / fps)]
frames = vr.get_batch(frame_indices).asnumpy() # 转换为 numpy 数组
frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
if not return_frame_indices:
return frames, video_duration
else:
return frames, video_duration, frame_indices
def sample_mp4_frames_by_indices(mp4_p, frame_indices: list):
if isinstance(mp4_p, str):
vr = decord.VideoReader(mp4_p, num_threads=1)
elif isinstance(mp4_p, decord.video_reader.VideoReader):
vr = mp4_p
# sample the frames in frame_indices
frames = vr.get_batch(frame_indices).asnumpy() # 转换为 numpy 数组
frames = [Image.fromarray(frame).convert("RGB") for frame in frames]
return frames
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import json
import os
import traceback
from PIL import Image, ImageFile, PngImagePlugin
from .data_utils import load_image, pil_img2rgb
from .distributed_iterable_dataset import DistributedIterableDataset
Image.MAX_IMAGE_PIXELS = 200000000
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2**20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
class SftJSONLIterableDataset(DistributedIterableDataset):
def __init__(
self,
dataset_name,
transform,
tokenizer,
frame_sampler,
jsonl_path_list,
data_dir_list,
num_used_data,
local_rank=0,
world_size=1,
num_workers=8,
data_status=None,
shuffle_lines=False,
shuffle_seed=0,
):
"""
jsonl_path_list: list of jsonl file paths
data_dir_list: list of image directories containing the images of each jsonl file
num_used_data: list of number of sampled data points for each jsonl
"""
super().__init__(dataset_name, local_rank, world_size, num_workers)
self.transform = transform
self.tokenizer = tokenizer
self.frame_sampler = frame_sampler
self.data_status = data_status
self.data_paths = self.get_data_paths(
jsonl_path_list,
data_dir_list,
num_used_data,
shuffle_lines,
shuffle_seed,
)
self.set_epoch()
def get_data_paths(
self,
jsonl_path_list,
data_dir_list,
num_used_data,
shuffle_lines,
shuffle_seed,
):
data_paths = []
for jsonl_path, image_dir, num_data_point in zip(
jsonl_path_list, data_dir_list, num_used_data
):
with open(jsonl_path, "r") as f:
raw_data = f.readlines()
if shuffle_lines:
self.rng.seed(shuffle_seed)
self.rng.shuffle(raw_data)
raw_data = raw_data[:num_data_point]
data_paths.extend([(json_data, image_dir) for json_data in raw_data])
return data_paths
def change_format(self, data, num_images):
elements = []
for conversation in data["conversations"]:
if conversation["from"] == "human":
if "<image>" not in conversation["value"]:
elements.append(
{
"type": "text",
"has_loss": 0,
"text": conversation["value"],
}
)
else:
text_list = conversation["value"].split("<image>")
for idx, text in enumerate(text_list):
if text.strip() != "":
elements.append(
{
"type": "text",
"has_loss": 0,
"text": text.strip(),
}
)
if (idx != len(text_list) - 1) and (idx < num_images):
elements.append(
{
"type": "image",
}
)
elif conversation["from"] == "gpt":
elements.append(
{
"type": "text",
"has_loss": 1,
"text": conversation["value"],
}
)
return elements
def __iter__(self):
data_paths_per_worker, worker_id = self.get_data_paths_per_worker()
if self.data_status is not None:
row_start_id = self.data_status[worker_id] + 1
else:
row_start_id = 0
transform_stride = self.transform.stride
print(
f"rank-{self.local_rank} worker-{worker_id} dataset-{self.dataset_name}: "
f"resuming data at row#{row_start_id}"
)
while True:
data_paths_per_worker_ = data_paths_per_worker[row_start_id:]
for row_idx, (data, image_dir) in enumerate(
data_paths_per_worker_, start=row_start_id
):
num_tokens = 0
image_tensor_list = []
text_ids_list = []
sequence_plan = []
try:
data_item = json.loads(data)
raw_images = None
if "image" in data_item:
if type(data_item["image"]) == list:
raw_images = [
pil_img2rgb(load_image(os.path.join(image_dir, image)))
for image in data_item["image"]
]
else:
raw_images = [
pil_img2rgb(
load_image(
os.path.join(image_dir, data_item["image"])
)
)
]
elif "video" in data_item:
raw_images = self.frame_sampler(
os.path.join(image_dir, data_item["video"])
)
special_tokens = "<image>" * len(raw_images)
for item in data_item["conversations"]:
if "<video>" in item["value"]:
item["value"] = item["value"].replace(
"<video>", special_tokens
)
break
else:
raise ValueError(
"Cannot find <video> in the conversation!"
)
except:
traceback.print_exc()
continue
if raw_images:
for raw_image in raw_images:
image_tensor = self.transform(
raw_image, img_num=len(raw_images)
)
image_tensor_list.append(image_tensor)
height, width = image_tensor.shape[1:]
num_tokens += width * height // transform_stride**2
elements = self.change_format(data_item, len(image_tensor_list))
for item in elements:
if item["type"] == "text":
text_data = item["text"]
text_ids = self.tokenizer.encode(text_data)
if len(text_ids) > 0:
text_ids_list.append(text_ids)
num_tokens += len(text_ids)
current_plan = {
"type": "text",
"enable_cfg": 0,
"loss": item["has_loss"],
"special_token_loss": 0,
"special_token_label": None,
}
sequence_plan.append(current_plan)
elif item["type"] == "image":
current_plan = {
"type": "vit_image",
"enable_cfg": 0,
"loss": 0,
"special_token_loss": 0,
"special_token_label": None,
}
sequence_plan.append(current_plan)
has_loss = [item["loss"] for item in sequence_plan]
if sum(has_loss) == 0:
print(f"No loss defined, skipped.")
continue
yield dict(
image_tensor_list=image_tensor_list,
text_ids_list=text_ids_list,
sequence_plan=sequence_plan,
num_tokens=num_tokens,
data_indexes={
"data_indexes": row_idx,
"worker_id": worker_id,
"dataset_name": self.dataset_name,
},
)
row_start_id = 0
print(
f"{self.dataset_name} repeat in rank-{self.local_rank} worker-{worker_id}"
)
name: bagel
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2025.2.25=h06a4308_0
- ld_impl_linux-64=2.40=h12ee557_0
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=3.0.16=h5eee18b_0
- pip=25.1=pyhc872135_2
- python=3.10.16=he870216_1
- readline=8.2=h5eee18b_0
- setuptools=78.1.1=py310h06a4308_0
- sqlite=3.45.3=h5eee18b_0
- tk=8.6.14=h39e8969_0
- wheel=0.45.1=py310h06a4308_0
- xz=5.6.4=h5eee18b_1
- zlib=1.2.13=h5eee18b_1
- pip:
- accelerate==1.7.0
- annotated-types==0.7.0
- certifi==2025.4.26
- charset-normalizer==3.4.2
- click==8.2.1
- contourpy==1.3.2
- cycler==0.12.1
- decord==0.6.0
- docker-pycreds==0.4.0
- einops==0.8.1
- filelock==3.18.0
- fonttools==4.58.0
- fsspec==2025.5.1
- gitdb==4.0.12
- gitpython==3.1.44
- huggingface-hub==0.29.1
- idna==3.10
- jinja2==3.1.6
- kiwisolver==1.4.8
- markupsafe==3.0.2
- matplotlib==3.7.0
- mpmath==1.3.0
- networkx==3.4.2
- ninja==1.11.1.4
- numpy==1.24.4
- nvidia-cublas-cu12==12.4.5.8
- nvidia-cuda-cupti-cu12==12.4.127
- nvidia-cuda-nvrtc-cu12==12.4.127
- nvidia-cuda-runtime-cu12==12.4.127
- nvidia-cudnn-cu12==9.1.0.70
- nvidia-cufft-cu12==11.2.1.3
- nvidia-curand-cu12==10.3.5.147
- nvidia-cusolver-cu12==11.6.1.9
- nvidia-cusparse-cu12==12.3.1.170
- nvidia-nccl-cu12==2.21.5
- nvidia-nvjitlink-cu12==12.4.127
- nvidia-nvtx-cu12==12.4.127
- opencv-python==4.7.0.72
- packaging==25.0
- pandas==2.3.0
- pillow==11.2.1
- platformdirs==4.3.8
- protobuf==6.31.0
- psutil==7.0.0
- pyarrow==11.0.0
- pydantic==2.11.5
- pydantic-core==2.33.2
- pyparsing==3.2.3
- python-dateutil==2.9.0.post0
- pytz==2025.2
- pyyaml==6.0.2
- regex==2024.11.6
- requests==2.32.3
- safetensors==0.4.5
- scipy==1.10.1
- sentencepiece==0.1.99
- sentry-sdk==2.29.1
- setproctitle==1.3.6
- six==1.17.0
- smmap==5.0.2
- sympy==1.13.1
- tokenizers==0.21.1
- torch==2.5.1
- torchvision==0.20.1
- tqdm==4.67.1
- transformers==4.49.0
- triton==3.1.0
- typing-extensions==4.13.2
- typing-inspection==0.4.1
- tzdata==2025.2
- urllib3==2.4.0
- wandb==0.19.11
\ No newline at end of file
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from . import autoencoder, bagel, qwen2, siglip
# Copyright (c) 2024 Black Forest Labs.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/black-forest-labs/flux/blob/main/LICENSE.
#
# This modified file is released under the same license.
from dataclasses import dataclass
import torch
from einops import rearrange
from safetensors.torch import load_file as load_sft
from torch import Tensor, nn
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
downsample: int
ch: int
out_ch: int
ch_mult: list[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = nn.GroupNorm(
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
# no asymmetric padding in torch conv, must do it ourselves
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def forward(self, x: Tensor):
pad = (0, 1, 0, 1)
x = nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x: Tensor):
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
block_in = self.ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = nn.Conv2d(
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
)
def forward(self, x: Tensor) -> Tensor:
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z: Tensor) -> Tensor:
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __init__(self, sample: bool = True, chunk_dim: int = 1):
super().__init__()
self.sample = sample
self.chunk_dim = chunk_dim
def forward(self, z: Tensor) -> Tensor:
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
if self.sample:
std = torch.exp(0.5 * logvar)
return mean + std * torch.randn_like(mean)
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def encode(self, x: Tensor) -> Tensor:
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: Tensor) -> Tensor:
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def forward(self, x: Tensor) -> Tensor:
return self.decode(self.encode(x))
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
if len(missing) > 0 and len(unexpected) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
print("\n" + "-" * 79 + "\n")
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
elif len(missing) > 0:
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
elif len(unexpected) > 0:
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
def load_ae(local_path: str | None) -> tuple[AutoEncoder, AutoEncoderParams]:
ae_params = AutoEncoderParams(
resolution=256,
in_channels=3,
downsample=8,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
)
# Loading the autoencoder
ae = AutoEncoder(ae_params)
if local_path is not None:
sd = load_sft(local_path)
missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True)
print_load_warning(missing, unexpected)
return ae, ae_params
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
from .bagel import Bagel, BagelConfig
from .qwen2_navit import Qwen2Config, Qwen2ForCausalLM, Qwen2Model
from .siglip_navit import SiglipVisionConfig, SiglipVisionModel
__all__ = [
"BagelConfig",
"Bagel",
"Qwen2Config",
"Qwen2Model",
"Qwen2ForCausalLM",
"SiglipVisionConfig",
"SiglipVisionModel",
]
# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import copy
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
from data.data_utils import (
create_sparse_mask,
get_flattened_position_ids_extrapolate,
get_flattened_position_ids_interpolate,
patchify,
)
from torch import nn
from torch.nn.attention.flex_attention import create_block_mask
from tqdm import tqdm
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel
from .modeling_utils import MLPconnector, PositionEmbedding, TimestepEmbedder
from .qwen2_navit import NaiveCache
class BagelConfig(PretrainedConfig):
def __init__(
self,
visual_gen=True,
visual_und=True,
llm_config=None,
vit_config=None,
vae_config=None,
latent_patch_size=2,
max_latent_size=32,
vit_max_num_patch_per_side=70,
connector_act="gelu_pytorch_tanh",
interpolate_pos=False,
timestep_shift=1.0,
**kwargs,
):
super().__init__(**kwargs)
self.visual_gen = visual_gen
self.visual_und = visual_und
self.llm_config = llm_config
self.vit_config = vit_config
self.vae_config = vae_config
self.latent_patch_size = latent_patch_size
self.max_latent_size = max_latent_size
self.vit_max_num_patch_per_side = vit_max_num_patch_per_side
self.connector_act = connector_act
self.interpolate_pos = interpolate_pos
self.timestep_shift = timestep_shift
class Bagel(PreTrainedModel):
config_class = BagelConfig
base_model_prefix = "bagel"
def __init__(self, language_model, vit_model, config: BagelConfig):
super().__init__(config)
self.language_model = language_model
if config.llm_config is None:
raise ValueError("llm_config cannot be None")
self.hidden_size = config.llm_config.hidden_size
self.use_moe = "Mo" in config.llm_config.layer_module
self.num_heads = config.llm_config.num_attention_heads
if config.visual_gen:
if config.vae_config is None:
raise ValueError("vae_config cannot be None when visual_gen is True")
self.latent_patch_size = config.latent_patch_size
self.timestep_shift = config.timestep_shift
self.latent_downsample = (
config.vae_config.downsample * config.latent_patch_size
)
self.max_latent_size = config.max_latent_size
self.latent_channel = config.vae_config.z_channels
self.patch_latent_dim = self.latent_patch_size**2 * self.latent_channel
self.time_embedder = TimestepEmbedder(self.hidden_size)
self.vae2llm = nn.Linear(self.patch_latent_dim, self.hidden_size)
self.llm2vae = nn.Linear(self.hidden_size, self.patch_latent_dim)
self.latent_pos_embed = PositionEmbedding(
self.max_latent_size, self.hidden_size
)
if config.visual_und:
if config.vit_config is None:
raise ValueError("vit_config cannot be None when visual_und is True")
self.vit_model = vit_model
self.vit_patch_size = config.vit_config.patch_size
self.vit_max_num_patch_per_side = config.vit_max_num_patch_per_side
self.vit_hidden_size = config.vit_config.hidden_size
self.connector = MLPconnector(
self.vit_hidden_size, self.hidden_size, config.connector_act
)
self.vit_pos_embed = PositionEmbedding(
self.vit_max_num_patch_per_side, self.hidden_size
)
if config.interpolate_pos:
self.get_flattened_position_ids = get_flattened_position_ids_interpolate
else:
self.get_flattened_position_ids = get_flattened_position_ids_extrapolate
self.config = config
self._init_weights()
def _init_weights(self):
if self.config.visual_gen:
nn.init.constant_(self.llm2vae.weight, 0)
nn.init.constant_(self.llm2vae.bias, 0)
def forward(
self,
sequence_length: int,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
sample_lens: List[int],
packed_position_ids: torch.LongTensor,
nested_attention_masks: Optional[List[torch.Tensor]] = None,
split_lens: Optional[List[int]] = None,
attn_modes: Optional[List[str]] = None,
# for visual understanding
ce_loss_indexes: Optional[torch.BoolTensor] = None,
packed_label_ids: Optional[torch.LongTensor] = None,
packed_vit_tokens: Optional[torch.Tensor] = None,
packed_vit_token_indexes: Optional[torch.LongTensor] = None,
packed_vit_position_ids: Optional[torch.LongTensor] = None,
vit_token_seqlens: Optional[torch.IntTensor] = None,
# for visual generation
padded_latent: Optional[torch.Tensor] = None,
patchified_vae_latent_shapes: Optional[List[Tuple[int, int]]] = None,
packed_latent_position_ids: Optional[torch.LongTensor] = None,
packed_vae_token_indexes: Optional[torch.LongTensor] = None,
packed_timesteps: Optional[torch.LongTensor] = None,
mse_loss_indexes: Optional[torch.BoolTensor] = None,
) -> torch.Tensor:
"""
Args:
sequence_length: length of sequence.
packed_text_ids: 1-D int tensor, packed text token ids.
packed_text_indexes: 1-D int tensor, packed text token indexes in sequence.
sample_lens: A list of N ints, length of each sample in packed_sequence.
nested_attention_masks: A list of N 2-D float tensor, where 0.0 means attention and
-inf means ignore.
packed_position_ids: packed 1-D positions, an image has only one global position shared
by all latent tokens.
packed_vit_tokens: packed patchified image tokens for vit model.
packed_vit_position_ids: 1-D int tensor, the position of each token for vit model.
packed_vit_token_indexes: 1-D int tensor, packed vit token indexes in sequence.
vit_token_seqlens: 1-D int tensor, the length of each image tokens for vit model.
packed_label_ids: 1-D int tensor, packed label token ids.
ce_loss_indexes: 1-D bool tensor, where to compute ce loss.
padded_latent: padded latent from VAE encoder.
patchified_vae_latent_shapes: A list of (h, w) tuples, patchfied latent shapes of each image.
packed_latent_position_ids: 1-D int tensor, the position of each token for latent.
packed_vae_token_indexes: 1-D int tensor, padded image token indexes in sequence.
packed_timesteps: 1-D float tensor, flow timesteps. 0 indicates use clean image.
mse_loss_indexes: 1-D bool tensor, where to compute mse loss.
"""
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
size=(sequence_length, self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
if nested_attention_masks is None:
sparse_mask = create_sparse_mask(
sample_lens, split_lens, attn_modes, packed_text_embedding.device
)
seqlen = sum(sample_lens)
block_mask = create_block_mask(
sparse_mask,
B=1,
H=self.num_heads,
Q_LEN=seqlen,
KV_LEN=seqlen,
device=packed_text_embedding.device,
BLOCK_SIZE=128,
_compile=True,
)
attention_mask = block_mask
else:
attention_mask = nested_attention_masks
if self.config.visual_und:
cu_seqlens = torch.nn.functional.pad(
torch.cumsum(vit_token_seqlens, dim=0), (1, 0)
)
cu_seqlens = cu_seqlens.to(torch.int32)
max_seqlen = torch.max(vit_token_seqlens).item()
packed_vit_token_embed = self.vit_model(
packed_pixel_values=packed_vit_tokens,
packed_flattened_position_ids=packed_vit_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
packed_vit_token_embed = self.connector(packed_vit_token_embed)
vit_token_pos_emb = self.vit_pos_embed(packed_vit_position_ids)
packed_vit_token_embed = packed_vit_token_embed + vit_token_pos_emb
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
if self.config.visual_gen:
p = self.latent_patch_size
packed_latent = []
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
latent = latent[:, : h * p, : w * p].reshape(
self.latent_channel, h, p, w, p
)
latent = torch.einsum("chpwq->hwpqc", latent).reshape(
-1, p * p * self.latent_channel
)
packed_latent.append(latent)
packed_latent_clean = torch.cat(packed_latent, dim=0)
noise = torch.randn_like(packed_latent_clean)
packed_timesteps = torch.sigmoid(packed_timesteps)
packed_timesteps = (
self.timestep_shift
* packed_timesteps
/ (1 + (self.timestep_shift - 1) * packed_timesteps)
)
packed_latent = (
1 - packed_timesteps[:, None]
) * packed_latent_clean + packed_timesteps[:, None] * noise
packed_timestep_embeds = self.time_embedder(packed_timesteps)
latent_token_pos_emb = self.latent_pos_embed(packed_latent_position_ids)
packed_latent = (
self.vae2llm(packed_latent)
+ packed_timestep_embeds
+ latent_token_pos_emb
)
packed_sequence[packed_vae_token_indexes] = packed_latent
extra_inputs = {}
if self.use_moe:
packed_und_token_indexes = packed_text_indexes
if packed_vit_token_indexes is not None:
packed_und_token_indexes = torch.cat(
[packed_text_indexes, packed_vit_token_indexes], dim=0
)
extra_inputs.update(
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_vae_token_indexes,
)
last_hidden_state = self.language_model(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_ids=packed_position_ids,
**extra_inputs,
)
mse = None
if self.config.visual_gen:
packed_mse_preds = self.llm2vae(last_hidden_state[mse_loss_indexes])
target = (
noise - packed_latent_clean
) # NOTE: v_t=dx_t/dt=x_1-x_0, pointing from data to noise
has_mse = packed_timesteps > 0
mse = (packed_mse_preds - target[has_mse]) ** 2
ce = None
if ce_loss_indexes is not None:
packed_ce_preds = self.language_model.lm_head(
last_hidden_state[ce_loss_indexes]
)
ce = F.cross_entropy(packed_ce_preds, packed_label_ids, reduction="none")
return dict(mse=mse, ce=ce)
def prepare_prompts(
self, curr_kvlens, curr_rope, prompts, tokenizer, new_token_ids
):
packed_text_ids = list()
packed_text_position_ids = list()
text_token_lens = list()
packed_text_indexes = list()
packed_key_value_indexes = list()
curr = 0
newlens, new_rope = list(), list()
for prompt, curr_kvlen, curr_position_id in zip(
prompts, curr_kvlens, curr_rope
):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
text_ids = tokenizer.encode(prompt)
text_ids = (
[new_token_ids["bos_token_id"]]
+ text_ids
+ [new_token_ids["eos_token_id"]]
)
text_token_lens.append(len(text_ids))
packed_text_ids.extend(text_ids)
packed_text_position_ids.extend(
range(curr_position_id, curr_position_id + len(text_ids))
)
packed_text_indexes.extend(range(curr, curr + len(text_ids)))
newlens.append(curr_kvlen + len(text_ids))
new_rope.append(curr_position_id + len(text_ids))
curr += len(text_ids)
generation_input = {
"text_token_lens": torch.tensor(text_token_lens, dtype=torch.int),
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_position_ids": torch.tensor(
packed_text_position_ids, dtype=torch.long
),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}
return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_text(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.IntTensor,
packed_text_position_ids: torch.LongTensor,
text_token_lens: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_text_embedding,
query_lens=text_token_lens,
packed_query_position_ids=packed_text_position_ids,
packed_query_indexes=packed_text_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=True,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vit_images(
self, curr_kvlens, curr_rope, images, transforms, new_token_ids
):
packed_vit_token_indexes = list()
vit_token_seqlens, packed_vit_tokens, packed_vit_position_ids = (
list(),
list(),
list(),
)
packed_text_ids, packed_text_indexes = list(), list()
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()
_curr = curr = 0
newlens, new_rope = list(), list()
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
image_tensor = transforms(image)
vit_position_ids = self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.vit_patch_size,
max_num_patches_per_side=self.vit_max_num_patch_per_side,
)
vit_tokens = patchify(image_tensor, self.vit_patch_size)
packed_vit_tokens.append(vit_tokens)
num_img_tokens = vit_tokens.shape[0]
packed_vit_position_ids.append(vit_position_ids)
vit_token_seqlens.append(num_img_tokens)
packed_vit_token_indexes.extend(range(_curr, _curr + num_img_tokens))
packed_indexes.extend(range(curr, curr + num_img_tokens))
curr += num_img_tokens
_curr += num_img_tokens
packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
packed_seqlens.append(num_img_tokens + 2)
newlens.append(curr_kvlen + num_img_tokens + 2)
new_rope.append(curr_position_id + 1)
generation_input = {
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"vit_token_seqlens": torch.tensor(vit_token_seqlens, dtype=torch.int),
"packed_vit_tokens": torch.cat(packed_vit_tokens, dim=0),
"packed_vit_position_ids": torch.cat(packed_vit_position_ids, dim=0),
"packed_vit_token_indexes": torch.tensor(
packed_vit_token_indexes, dtype=torch.long
),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}
return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vit(
self,
past_key_values: NaiveCache,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_vit_tokens: torch.Tensor,
packed_vit_token_indexes: torch.LongTensor,
packed_vit_position_ids: torch.LongTensor,
vit_token_seqlens: torch.IntTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
(sum(packed_seqlens), self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
cu_seqlens = torch.nn.functional.pad(
torch.cumsum(vit_token_seqlens, dim=0), (1, 0)
)
cu_seqlens = cu_seqlens.to(torch.int32)
max_seqlen = torch.max(vit_token_seqlens).item()
packed_vit_token_embed = self.vit_model(
packed_pixel_values=packed_vit_tokens,
packed_flattened_position_ids=packed_vit_position_ids,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
packed_vit_token_embed = self.connector(packed_vit_token_embed)
pos_emb = self.vit_pos_embed(packed_vit_position_ids)
packed_vit_token_embed = packed_vit_token_embed + pos_emb
if packed_vit_token_embed.dtype != packed_sequence.dtype:
packed_vit_token_embed = packed_vit_token_embed.to(packed_sequence.dtype)
packed_sequence[packed_vit_token_indexes] = packed_vit_token_embed
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
key_values_lens=key_values_lens,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vae_images(
self, curr_kvlens, curr_rope, images, transforms, new_token_ids, timestep=0
):
patchified_vae_latent_shapes, packed_vae_position_ids = list(), list()
packed_vae_token_indexes = list()
packed_text_ids, packed_text_indexes = list(), list()
packed_seqlens, packed_position_ids, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()
_curr = curr = 0
vae_image_tensors = list()
newlens, new_rope = list(), list()
for image, curr_kvlen, curr_position_id in zip(images, curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
image_tensor = transforms(image)
vae_image_tensors.append(image_tensor)
vae_posiiton_ids = self.get_flattened_position_ids(
image_tensor.size(1),
image_tensor.size(2),
self.latent_downsample,
max_num_patches_per_side=self.max_latent_size,
)
packed_vae_position_ids.append(vae_posiiton_ids)
H, W = image_tensor.shape[1:]
h = H // self.latent_downsample
w = W // self.latent_downsample
patchified_vae_latent_shapes.append((h, w))
num_img_tokens = w * h
packed_vae_token_indexes.extend(range(_curr, _curr + num_img_tokens))
packed_indexes.extend(range(curr, curr + num_img_tokens))
curr += num_img_tokens
_curr += num_img_tokens
packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(_curr)
packed_indexes.append(curr)
curr += 1
_curr += 1
packed_position_ids.extend([curr_position_id] * (num_img_tokens + 2))
packed_seqlens.append(num_img_tokens + 2)
newlens.append(curr_kvlen + num_img_tokens + 2)
new_rope.append(curr_position_id + 1)
image_sizes = [item.shape for item in vae_image_tensors]
max_image_size = [max(item) for item in list(zip(*image_sizes))]
padded_images = torch.zeros(size=(len(vae_image_tensors), *max_image_size))
for i, image_tensor in enumerate(vae_image_tensors):
padded_images[i, :, : image_tensor.shape[1], : image_tensor.shape[2]] = (
image_tensor
)
generation_input = {
"padded_images": padded_images,
"patchified_vae_latent_shapes": patchified_vae_latent_shapes,
"packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
"packed_timesteps": torch.tensor([timestep]),
"packed_vae_token_indexes": torch.tensor(
packed_vae_token_indexes, dtype=torch.long
),
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
}
return generation_input, newlens, new_rope
@torch.no_grad
def forward_cache_update_vae(
self,
vae_model,
past_key_values: NaiveCache,
padded_images: torch.Tensor,
patchified_vae_latent_shapes: List,
packed_vae_position_ids: torch.LongTensor,
packed_timesteps: torch.Tensor,
packed_vae_token_indexes: torch.LongTensor,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
packed_key_value_indexes: torch.Tensor,
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
(sum(packed_seqlens), self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
padded_latent = vae_model.encode(padded_images)
p = self.latent_patch_size
packed_latent = list()
for latent, (h, w) in zip(padded_latent, patchified_vae_latent_shapes):
latent = latent[:, : h * p, : w * p].reshape(
self.latent_channel, h, p, w, p
)
latent = torch.einsum("chpwq->hwpqc", latent).reshape(
-1, p * p * self.latent_channel
)
packed_latent.append(latent)
packed_latent = torch.cat(packed_latent, dim=0)
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
packed_timestep_embeds = self.time_embedder(packed_timesteps)
packed_latent = (
self.vae2llm(packed_latent) + packed_timestep_embeds + packed_pos_embed
)
if packed_latent.dtype != packed_sequence.dtype:
packed_latent = packed_latent.to(packed_sequence.dtype)
packed_sequence[packed_vae_token_indexes] = packed_latent
extra_inputs = {}
if self.use_moe:
extra_inputs = {
"mode": "gen",
"packed_vae_token_indexes": packed_vae_token_indexes,
"packed_text_indexes": packed_text_indexes,
}
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=True,
is_causal=False,
**extra_inputs,
)
past_key_values = output.past_key_values
return past_key_values
def prepare_vae_latent(self, curr_kvlens, curr_rope, image_sizes, new_token_ids):
packed_text_ids, packed_text_indexes = list(), list()
packed_vae_position_ids, packed_vae_token_indexes, packed_init_noises = (
list(),
list(),
list(),
)
packed_position_ids, packed_seqlens, packed_indexes = list(), list(), list()
packed_key_value_indexes = list()
query_curr = curr = 0
for (H, W), curr_kvlen, curr_position_id in zip(
image_sizes, curr_kvlens, curr_rope
):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_text_ids.append(new_token_ids["start_of_image"])
packed_text_indexes.append(query_curr)
packed_indexes.append(curr)
curr += 1
query_curr += 1
vae_posiiton_ids = self.get_flattened_position_ids(
H,
W,
self.latent_downsample,
max_num_patches_per_side=self.max_latent_size,
)
packed_vae_position_ids.append(vae_posiiton_ids)
h, w = H // self.latent_downsample, W // self.latent_downsample
num_image_tokens = h * w
packed_init_noises.append(
torch.randn(
num_image_tokens, self.latent_channel * self.latent_patch_size**2
)
)
packed_vae_token_indexes.extend(
range(query_curr, query_curr + num_image_tokens)
)
packed_indexes.extend(range(curr, curr + num_image_tokens))
curr += num_image_tokens
query_curr += num_image_tokens
packed_text_ids.append(new_token_ids["end_of_image"])
packed_text_indexes.append(query_curr)
packed_indexes.append(curr)
curr += 1
query_curr += 1
packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
packed_seqlens.append(num_image_tokens + 2)
generation_input = {
"packed_text_ids": torch.tensor(packed_text_ids, dtype=torch.long),
"packed_text_indexes": torch.tensor(packed_text_indexes, dtype=torch.long),
"packed_init_noises": torch.cat(packed_init_noises, dim=0),
"packed_vae_position_ids": torch.cat(packed_vae_position_ids, dim=0),
"packed_vae_token_indexes": torch.tensor(
packed_vae_token_indexes, dtype=torch.long
),
"packed_seqlens": torch.tensor(packed_seqlens, dtype=torch.int),
"packed_position_ids": torch.tensor(packed_position_ids, dtype=torch.long),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
"packed_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
}
return generation_input
def prepare_vae_latent_cfg(self, curr_kvlens, curr_rope, image_sizes):
packed_position_ids, packed_indexes, packed_key_value_indexes = (
list(),
list(),
list(),
)
query_curr = curr = 0
for (H, W), curr_kvlen, curr_position_id in zip(
image_sizes, curr_kvlens, curr_rope
):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
curr += curr_kvlen
packed_indexes.append(curr)
curr += 1
query_curr += 1
h, w = H // self.latent_downsample, W // self.latent_downsample
num_image_tokens = h * w
packed_indexes.extend(range(curr, curr + num_image_tokens))
curr += num_image_tokens
query_curr += num_image_tokens
packed_indexes.append(curr)
curr += 1
query_curr += 1
packed_position_ids.extend([curr_position_id] * (num_image_tokens + 2))
generation_input = {
"cfg_packed_position_ids": torch.tensor(
packed_position_ids, dtype=torch.long
),
"cfg_key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
"cfg_packed_query_indexes": torch.tensor(packed_indexes, dtype=torch.long),
"cfg_packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
}
return generation_input
@torch.no_grad
def generate_image(
self,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_init_noises: torch.Tensor,
packed_vae_position_ids: torch.LongTensor,
packed_vae_token_indexes: torch.LongTensor,
packed_seqlens: torch.IntTensor,
packed_position_ids: torch.LongTensor,
packed_indexes: torch.LongTensor,
past_key_values: NaiveCache,
key_values_lens: torch.IntTensor,
packed_key_value_indexes: torch.LongTensor,
num_timesteps: int = 24,
timestep_shift: float = 1.0,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
cfg_interval: Optional[Tuple[float, float]] = [0, 1],
# cfg_text
cfg_text_scale: float = 1.0,
cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_text_past_key_values: Optional[NaiveCache] = None,
cfg_text_key_values_lens: Optional[torch.IntTensor] = None,
cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
# cfg_img
cfg_img_scale: float = 1.0,
cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_img_past_key_values: Optional[NaiveCache] = None,
cfg_img_key_values_lens: Optional[torch.IntTensor] = None,
cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
cfg_type: str = "parallel",
):
x_t = packed_init_noises
timesteps = torch.linspace(1, 0, num_timesteps, device=x_t.device)
timesteps = timestep_shift * timesteps / (1 + (timestep_shift - 1) * timesteps)
dts = timesteps[:-1] - timesteps[1:]
timesteps = timesteps[:-1]
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
if t > cfg_interval[0] and t <= cfg_interval[1]:
cfg_text_scale_ = cfg_text_scale
cfg_img_scale_ = cfg_img_scale
else:
cfg_text_scale_ = 1.0
cfg_img_scale_ = 1.0
v_t = self._forward_flow(
x_t=x_t,
timestep=timestep,
packed_vae_token_indexes=packed_vae_token_indexes,
packed_vae_position_ids=packed_vae_position_ids,
packed_text_ids=packed_text_ids,
packed_text_indexes=packed_text_indexes,
packed_position_ids=packed_position_ids,
packed_indexes=packed_indexes,
packed_seqlens=packed_seqlens,
key_values_lens=key_values_lens,
past_key_values=past_key_values,
packed_key_value_indexes=packed_key_value_indexes,
cfg_renorm_min=cfg_renorm_min,
cfg_renorm_type=cfg_renorm_type,
# cfg_text
cfg_text_scale=cfg_text_scale_,
cfg_text_packed_position_ids=cfg_text_packed_position_ids,
cfg_text_packed_query_indexes=cfg_text_packed_query_indexes,
cfg_text_key_values_lens=cfg_text_key_values_lens,
cfg_text_past_key_values=cfg_text_past_key_values,
cfg_text_packed_key_value_indexes=cfg_text_packed_key_value_indexes,
# cfg_img
cfg_img_scale=cfg_img_scale_,
cfg_img_packed_position_ids=cfg_img_packed_position_ids,
cfg_img_packed_query_indexes=cfg_img_packed_query_indexes,
cfg_img_key_values_lens=cfg_img_key_values_lens,
cfg_img_past_key_values=cfg_img_past_key_values,
cfg_img_packed_key_value_indexes=cfg_img_packed_key_value_indexes,
cfg_type=cfg_type,
)
x_t = (
x_t - v_t.to(x_t.device) * dts[i]
) # velocity pointing from data to noise
unpacked_latent = x_t.split((packed_seqlens - 2).tolist())
return unpacked_latent
@torch.no_grad
def _forward_flow(
self,
x_t: torch.Tensor,
timestep: torch.LongTensor,
packed_vae_token_indexes: torch.LongTensor,
packed_vae_position_ids: torch.LongTensor,
packed_text_ids: torch.LongTensor,
packed_text_indexes: torch.LongTensor,
packed_indexes: torch.LongTensor,
packed_position_ids: torch.LongTensor,
packed_seqlens: torch.IntTensor,
key_values_lens: torch.IntTensor,
past_key_values: NaiveCache,
packed_key_value_indexes: torch.LongTensor,
cfg_renorm_min: float = 0.0,
cfg_renorm_type: str = "global",
# cfg_text
cfg_text_scale: float = 1.0,
cfg_text_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_text_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_text_key_values_lens: Optional[torch.Tensor] = None,
cfg_text_past_key_values: Optional[NaiveCache] = None,
cfg_text_packed_key_value_indexes: Optional[torch.LongTensor] = None,
# cfg_img
cfg_img_scale: float = 1.0,
cfg_img_packed_position_ids: Optional[torch.LongTensor] = None,
cfg_img_packed_query_indexes: Optional[torch.LongTensor] = None,
cfg_img_key_values_lens: Optional[torch.Tensor] = None,
cfg_img_past_key_values: Optional[NaiveCache] = None,
cfg_img_packed_key_value_indexes: Optional[torch.LongTensor] = None,
cfg_type: str = "parallel",
):
packed_text_embedding = self.language_model.model.embed_tokens(packed_text_ids)
packed_sequence = packed_text_embedding.new_zeros(
(sum(packed_seqlens), self.hidden_size)
)
packed_sequence[packed_text_indexes] = packed_text_embedding
assert timestep.unique().shape[0] == 1
packed_pos_embed = self.latent_pos_embed(packed_vae_position_ids)
packed_timestep_embeds = self.time_embedder(timestep)
x_t = self.vae2llm(x_t) + packed_timestep_embeds + packed_pos_embed
if x_t.dtype != packed_sequence.dtype:
x_t = x_t.to(packed_sequence.dtype)
packed_sequence[packed_vae_token_indexes] = x_t
extra_inputs = {}
if self.use_moe:
extra_inputs = {
"mode": "gen",
"packed_vae_token_indexes": packed_vae_token_indexes,
"packed_text_indexes": packed_text_indexes,
}
output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=packed_position_ids,
packed_query_indexes=packed_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
v_t = self.llm2vae(output.packed_query_sequence)
v_t = v_t[packed_vae_token_indexes]
if cfg_text_scale > 1.0:
cfg_text_output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=cfg_text_packed_position_ids,
packed_query_indexes=cfg_text_packed_query_indexes,
past_key_values=cfg_text_past_key_values,
key_values_lens=cfg_text_key_values_lens,
packed_key_value_indexes=cfg_text_packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_text_v_t = self.llm2vae(cfg_text_output.packed_query_sequence)
cfg_text_v_t = cfg_text_v_t[packed_vae_token_indexes]
if cfg_img_scale > 1.0:
cfg_img_output = self.language_model.forward_inference(
packed_query_sequence=packed_sequence,
query_lens=packed_seqlens,
packed_query_position_ids=cfg_img_packed_position_ids,
packed_query_indexes=cfg_img_packed_query_indexes,
past_key_values=cfg_img_past_key_values,
key_values_lens=cfg_img_key_values_lens,
packed_key_value_indexes=cfg_img_packed_key_value_indexes,
update_past_key_values=False,
is_causal=False,
**extra_inputs,
)
cfg_img_v_t = self.llm2vae(cfg_img_output.packed_query_sequence)
cfg_img_v_t = cfg_img_v_t[packed_vae_token_indexes]
if cfg_text_scale > 1.0:
if cfg_renorm_type == "text_channel":
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_text_ = torch.norm(v_t_text_, dim=-1, keepdim=True)
scale = (norm_v_t / (norm_v_t_text_ + 1e-8)).clamp(
min=cfg_renorm_min, max=1.0
)
v_t_text = v_t_text_ * scale
if cfg_img_scale > 1.0:
v_t = cfg_img_v_t + cfg_img_scale * (v_t_text - cfg_img_v_t)
else:
v_t = v_t_text
else:
v_t_text_ = cfg_text_v_t + cfg_text_scale * (v_t - cfg_text_v_t)
if cfg_img_scale > 1.0:
v_t_ = cfg_img_v_t + cfg_img_scale * (v_t_text_ - cfg_img_v_t)
else:
v_t_ = v_t_text_
# NOTE norm is computed over all dimensions, thus currently only supports batch_size = 1 with navit
if cfg_renorm_type == "global":
norm_v_t = torch.norm(v_t)
norm_v_t_ = torch.norm(v_t_)
elif cfg_renorm_type == "channel":
norm_v_t = torch.norm(v_t, dim=-1, keepdim=True)
norm_v_t_ = torch.norm(v_t_, dim=-1, keepdim=True)
else:
raise NotImplementedError(f"{cfg_renorm_type} is not suppoprted")
scale = (norm_v_t / (norm_v_t_ + 1e-8)).clamp(
min=cfg_renorm_min, max=1.0
)
v_t = v_t_ * scale
else:
# No CFG
pass
return v_t
def prepare_start_tokens(self, curr_kvlens, curr_rope, new_token_ids):
packed_start_tokens, packed_key_value_indexes = list(), list()
packed_query_position_ids = list()
curr = 0
for curr_kvlen, curr_position_id in zip(curr_kvlens, curr_rope):
packed_key_value_indexes.extend(range(curr, curr + curr_kvlen))
packed_start_tokens.append(new_token_ids["bos_token_id"])
packed_query_position_ids.append(curr_position_id)
curr += curr_kvlen
generation_input = {
"packed_start_tokens": torch.tensor(packed_start_tokens, dtype=torch.long),
"packed_query_position_ids": torch.tensor(
packed_query_position_ids, dtype=torch.long
),
"key_values_lens": torch.tensor(curr_kvlens, dtype=torch.int),
"packed_key_value_indexes": torch.tensor(
packed_key_value_indexes, dtype=torch.long
),
}
return generation_input
@torch.no_grad
def generate_text(
self,
past_key_values: NaiveCache,
packed_key_value_indexes: torch.LongTensor,
key_values_lens: torch.IntTensor,
packed_start_tokens: torch.LongTensor,
packed_query_position_ids: torch.LongTensor,
max_length: int,
do_sample: bool = False,
temperature: float = 1.0,
end_token_id: int = None,
):
step = 0
generated_sequence = []
curr_tokens = packed_start_tokens
while step < max_length:
generated_sequence.append(curr_tokens)
packed_text_embedding = self.language_model.model.embed_tokens(curr_tokens)
query_lens = torch.ones_like(curr_tokens)
packed_query_indexes = torch.cumsum(key_values_lens, dim=0) + torch.arange(
0,
len(key_values_lens),
device=key_values_lens.device,
dtype=key_values_lens.dtype,
)
uppacked = list(
packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)
)
for i in range(len(uppacked)):
uppacked[i] += i
packed_key_value_indexes = torch.cat(uppacked, dim=0)
extra_inputs = {}
if self.use_moe:
extra_inputs = {"mode": "und"}
output = self.language_model.forward_inference(
packed_query_sequence=packed_text_embedding,
query_lens=query_lens,
packed_query_position_ids=packed_query_position_ids,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=True,
is_causal=True,
**extra_inputs,
)
past_key_values = output.past_key_values
packed_query_sequence = output.packed_query_sequence
pred_logits = self.language_model.lm_head(packed_query_sequence)
if do_sample:
probs = nn.functional.softmax(pred_logits / temperature, dim=-1)
curr_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
curr_tokens = torch.argmax(pred_logits, dim=-1)
uppacked = list(
packed_key_value_indexes.split(key_values_lens.tolist(), dim=0)
)
for i in range(len(uppacked)):
uppacked[i] = torch.cat(
[
uppacked[i],
torch.tensor([uppacked[i][-1] + 1], device=uppacked[i].device),
],
dim=0,
)
packed_key_value_indexes = torch.cat(uppacked, dim=0)
key_values_lens = key_values_lens + 1
packed_query_position_ids = packed_query_position_ids + 1
step += 1
if (
end_token_id is not None and curr_tokens[0] == end_token_id
): # only support batch=1
break
output_device = generated_sequence[0].device
return torch.stack([i.to(output_device) for i in generated_sequence], dim=0)
# for evaluation
@torch.no_grad()
def chat(
self,
tokenizer,
new_token_ids,
image_transform,
images,
prompt,
max_length: int,
do_sample: bool = False,
temperature: float = 1.0,
):
device = next(self.parameters()).device
if isinstance(new_token_ids, dict):
for k, v in new_token_ids.items():
if torch.is_tensor(v):
new_token_ids[k] = v.to(device)
elif torch.is_tensor(new_token_ids):
new_token_ids = new_token_ids.to(device)
# prefill
past_key_values = NaiveCache(self.config.llm_config.num_hidden_layers)
newlens = [0]
new_rope = [0]
# add images
for image in images:
generation_input, newlens, new_rope = self.prepare_vit_images(
curr_kvlens=newlens,
curr_rope=new_rope,
images=[image],
transforms=image_transform,
new_token_ids=new_token_ids,
)
for k, v in generation_input.items():
if torch.is_tensor(v):
generation_input[k] = v.to(device)
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
past_key_values = self.forward_cache_update_vit(
past_key_values, **generation_input
)
# add text
generation_input, newlens, new_rope = self.prepare_prompts(
curr_kvlens=newlens,
curr_rope=new_rope,
prompts=[prompt],
tokenizer=tokenizer,
new_token_ids=new_token_ids,
)
for k, v in generation_input.items():
if torch.is_tensor(v):
generation_input[k] = v.to(device)
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
past_key_values = self.forward_cache_update_text(
past_key_values, **generation_input
)
# decode
generation_input = self.prepare_start_tokens(newlens, new_rope, new_token_ids)
for k, v in generation_input.items():
if torch.is_tensor(v):
generation_input[k] = v.to(device)
with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
unpacked_latent = self.generate_text(
past_key_values=past_key_values,
max_length=max_length,
do_sample=do_sample,
temperature=temperature,
end_token_id=new_token_ids["eos_token_id"],
**generation_input,
)
output = tokenizer.decode(unpacked_latent[:, 0])
output = output.split("<|im_end|>")[0].split("<|im_start|>")[1]
return output
# Copyright (c) 2022 Facebook, Inc. and its affiliates.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: CC BY-NC 4.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under CC BY-NC 4.0, with the full license text
# available at https://github.com/facebookresearch/DiT/blob/main/LICENSE.txt.
#
# This modified file is released under the same license.
import math
import numpy as np
import torch
from torch import nn
from transformers.activations import ACT2FN
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# --------------------------------------------------------
# TimestepEmbedder
# Reference:
# DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
# --------------------------------------------------------
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq)
return t_emb
class MLPconnector(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_act: str):
super().__init__()
self.activation_fn = ACT2FN[hidden_act]
self.fc1 = nn.Linear(in_dim, out_dim)
self.fc2 = nn.Linear(out_dim, out_dim)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class PositionEmbedding(nn.Module):
def __init__(self, max_num_patch_per_side, hidden_size):
super().__init__()
self.max_num_patch_per_side = max_num_patch_per_side
self.hidden_size = hidden_size
self.pos_embed = nn.Parameter(
torch.zeros(max_num_patch_per_side**2, hidden_size), requires_grad=False
)
self._init_weights()
def _init_weights(self):
# Initialize (and freeze) pos_embed by sin-cos embedding:
pos_embed = get_2d_sincos_pos_embed(
self.hidden_size, self.max_num_patch_per_side
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float())
def forward(self, position_ids):
return self.pos_embed[position_ids]
# Copyright (c) 2024 The Qwen Team and The HuggingFace Inc. team.
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025-05-20.
#
# Original file was released under Apache-2.0, with the full license text
# available at https://github.com/huggingface/transformers/blob/main/LICENSE.
#
# This modified file is released under the same license.
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import torch
from flash_attn import flash_attn_varlen_func
from torch import nn
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.attention.flex_attention import flex_attention
from torch.nn.functional import scaled_dot_product_attention
from transformers.utils import ModelOutput
from modeling.qwen2.configuration_qwen2 import Qwen2Config as _Qwen2Config
from modeling.qwen2.modeling_qwen2 import (
Qwen2Attention,
Qwen2MLP,
Qwen2PreTrainedModel,
Qwen2RMSNorm,
Qwen2RotaryEmbedding,
apply_rotary_pos_emb,
)
torch._dynamo.config.cache_size_limit = 512
torch._dynamo.config.accumulated_cache_size_limit = 4096
# flex_attention = torch.compile(flex_attention) # , dynamic=True, mode='max-autotune'
flex_attention = torch.compile(flex_attention)
class Qwen2Config(_Qwen2Config):
r"""
This is the configuration class to store the configuration of a [`Qwen2Model`]. It is used to instantiate a
Qwen2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of
Qwen2-7B-beta [Qwen/Qwen2-7B-beta](https://huggingface.co/Qwen/Qwen2-7B-beta).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 151936):
Vocabulary size of the Qwen2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Qwen2Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 22016):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*, defaults to 32):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 32768):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
accordingly.
Expected contents:
`rope_type` (`str`):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
`factor` (`float`, *optional*):
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
original maximum pre-trained length.
`original_max_position_embeddings` (`int`, *optional*):
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
pretraining.
`attention_factor` (`float`, *optional*):
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
computation. If unspecified, it defaults to value recommended by the implementation, using the
`factor` field to infer the suggested value.
`beta_fast` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
ramp function. If unspecified, it defaults to 32.
`beta_slow` (`float`, *optional*):
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
ramp function. If unspecified, it defaults to 1.
`short_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`long_factor` (`List[float]`, *optional*):
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
size divided by the number of attention heads divided by 2
`low_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
`high_freq_factor` (`float`, *optional*):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
use_sliding_window (`bool`, *optional*, defaults to `False`):
Whether to use sliding window attention.
sliding_window (`int`, *optional*, defaults to 4096):
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
max_window_layers (`int`, *optional*, defaults to 28):
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import Qwen2Model, Qwen2Config
>>> # Initializing a Qwen2 style configuration
>>> configuration = Qwen2Config()
>>> # Initializing a model from the Qwen2-7B style configuration
>>> model = Qwen2Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "qwen2"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
intermediate_size=22016,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=32,
hidden_act="silu",
max_position_embeddings=32768,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
use_sliding_window=False,
sliding_window=4096,
max_window_layers=28,
attention_dropout=0.0,
is_causal=True,
_attn_implementation="flash_attention_2",
qk_norm=True,
layer_module="Qwen2DecoderLayer",
freeze_und=False,
**kwargs,
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
hidden_act=hidden_act,
max_position_embeddings=max_position_embeddings,
initializer_range=initializer_range,
rms_norm_eps=rms_norm_eps,
use_cache=use_cache,
tie_word_embeddings=tie_word_embeddings,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
use_sliding_window=use_sliding_window,
sliding_window=sliding_window,
max_window_layers=max_window_layers,
attention_dropout=attention_dropout,
is_causal=is_causal,
_attn_implementation=_attn_implementation,
**kwargs,
)
self.qk_norm = qk_norm
self.layer_module = layer_module
self.freeze_und = freeze_und
class NaiveCache:
def __init__(self, num_layers):
self.key_cache = {k: None for k in range(num_layers)}
self.value_cache = {k: None for k in range(num_layers)}
@property
def num_layers(self):
return len(self.key_cache)
@property
def seq_lens(self):
if self.key_cache[0] is not None:
return self.key_cache[0].shape[0]
else:
return 0
@dataclass
class BaseNavitOutputWithPast(ModelOutput):
packed_query_sequence: torch.FloatTensor = None
past_key_values: Optional[NaiveCache] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
def pad_sequence(tensor, pad_size):
H, L, D = tensor.shape
pad_tensor = tensor.new_zeros((H, pad_size, D))
return torch.cat([tensor, pad_tensor], dim=1)
class PackedAttention(Qwen2Attention):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
if self.config.qk_norm:
self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask: List[torch.Tensor],
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
):
packed_query_states = self.q_proj(packed_sequence).view(
-1, self.num_heads, self.head_dim
)
packed_key_states = self.k_proj(packed_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = self.v_proj(packed_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = self.q_norm(packed_query_states)
packed_key_states = self.k_norm(packed_key_states)
packed_cos, packed_sin = packed_position_embeddings
packed_query_states, packed_key_states = apply_rotary_pos_emb(
packed_query_states,
packed_key_states,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
if isinstance(attention_mask, List):
packed_key_states = packed_key_states[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_key_states = packed_key_states.reshape(
-1, self.num_heads, self.head_dim
)
packed_value_states = packed_value_states[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_value_states = packed_value_states.reshape(
-1, self.num_heads, self.head_dim
)
unpacked_query_states = packed_query_states.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_key_states = packed_key_states.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_value_states = packed_value_states.transpose(0, 1).split(
sample_lens, dim=1
)
upacked_attn_output = []
for (
query_states,
key_states,
value_states,
attention_mask_per_sample,
) in zip(
unpacked_query_states,
unpacked_key_states,
unpacked_value_states,
attention_mask,
):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
attn_output = scaled_dot_product_attention(
query_states.to(torch.bfloat16).unsqueeze(0),
key_states.to(torch.bfloat16).unsqueeze(0),
value_states.to(torch.bfloat16).unsqueeze(0),
attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
)
upacked_attn_output.append(attn_output.squeeze(0))
packed_attn_output = torch.cat(upacked_attn_output, dim=1)
else:
pad_size = sum(sample_lens) - packed_query_states.shape[0]
packed_query_states = pad_sequence(
packed_query_states.permute(1, 0, 2), pad_size
)
packed_key_states = pad_sequence(
packed_key_states.permute(1, 0, 2), pad_size
)
packed_value_states = pad_sequence(
packed_value_states.permute(1, 0, 2), pad_size
)
packed_attn_output = flex_attention(
packed_query_states.unsqueeze(0),
packed_key_states.unsqueeze(0),
packed_value_states.unsqueeze(0),
enable_gqa=True,
block_mask=attention_mask,
)
end_index = packed_attn_output.shape[2] - pad_size
packed_attn_output = packed_attn_output[0, :, :end_index, :]
packed_attn_output = packed_attn_output.transpose(0, 1).reshape(
-1, self.hidden_size
)
packed_attn_output = self.o_proj(packed_attn_output)
return packed_attn_output
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
output_attentions=False,
):
packed_query_states = self.q_proj(packed_query_sequence).view(
-1, self.num_heads, self.head_dim
)
packed_key_states = self.k_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = self.v_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = self.q_norm(packed_query_states)
packed_key_states = self.k_norm(packed_key_states)
packed_cos, packed_sin = packed_query_position_embeddings
packed_query_states, packed_key_states = apply_rotary_pos_emb(
packed_query_states,
packed_key_states,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
packed_query_states = packed_query_states.to(torch.bfloat16)
packed_key_states = packed_key_states.to(torch.bfloat16)
packed_value_states = packed_value_states.to(torch.bfloat16)
if (
past_key_values is not None
and past_key_values.key_cache[self.layer_idx] is not None
):
past_key_states = past_key_values.key_cache[self.layer_idx]
past_value_states = past_key_values.value_cache[self.layer_idx]
seqlens = sum(query_lens) + sum(key_values_lens)
merged_key_states = past_key_states.new_zeros(
(seqlens, self.num_key_value_heads, self.head_dim)
)
merged_value_states = past_key_states.new_zeros(
(seqlens, self.num_key_value_heads, self.head_dim)
)
merged_key_states[packed_query_indexes] = packed_key_states
merged_key_states[packed_key_value_indexes] = past_key_states
merged_value_states[packed_query_indexes] = packed_value_states
merged_value_states[packed_key_value_indexes] = past_value_states
key_values_lens = key_values_lens + query_lens
else:
merged_key_states = packed_key_states
merged_value_states = packed_value_states
key_values_lens = query_lens
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(key_values_lens, dim=0), (1, 0)
)
packed_attn_output = flash_attn_varlen_func(
q=packed_query_states,
k=merged_key_states,
v=merged_value_states,
cu_seqlens_q=cu_seqlens_q.to(torch.int32),
cu_seqlens_k=cu_seqlens_k.to(torch.int32),
max_seqlen_q=max(query_lens).item(),
max_seqlen_k=max(key_values_lens).item(),
causal=is_causal,
)
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
packed_attn_output = self.o_proj(packed_attn_output)
if update_past_key_values:
past_key_values.key_cache[self.layer_idx] = merged_key_states
past_key_values.value_cache[self.layer_idx] = merged_value_states
return packed_attn_output, past_key_values
class PackedAttentionMoT(Qwen2Attention):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
if self.config.qk_norm:
self.q_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.q_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.k_norm_moe_gen = Qwen2RMSNorm(self.head_dim, eps=config.rms_norm_eps)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.q_norm_moe_gen = nn.Identity()
self.k_norm_moe_gen = nn.Identity()
self.q_proj_moe_gen = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=True
)
self.k_proj_moe_gen = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.v_proj_moe_gen = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
)
self.o_proj_moe_gen = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
):
packed_query_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_heads * self.head_dim)
)
packed_key_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
)
packed_value_states = packed_sequence.new_zeros(
(packed_sequence.shape[0], self.num_key_value_heads * self.head_dim)
)
packed_sequence_und = packed_sequence[packed_und_token_indexes]
packed_sequence_gen = packed_sequence[packed_gen_token_indexes]
packed_query_states[packed_und_token_indexes] = self.q_proj(packed_sequence_und)
packed_query_states[packed_gen_token_indexes] = self.q_proj_moe_gen(
packed_sequence_gen
)
packed_key_states[packed_und_token_indexes] = self.k_proj(packed_sequence_und)
packed_key_states[packed_gen_token_indexes] = self.k_proj_moe_gen(
packed_sequence_gen
)
packed_value_states[packed_und_token_indexes] = self.v_proj(packed_sequence_und)
packed_value_states[packed_gen_token_indexes] = self.v_proj_moe_gen(
packed_sequence_gen
)
packed_query_states = packed_query_states.view(
-1, self.num_heads, self.head_dim
)
packed_key_states = packed_key_states.view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = packed_value_states.view(
-1, self.num_key_value_heads, self.head_dim
)
if self.config.freeze_und:
packed_value_states[packed_und_token_indexes] = packed_value_states[
packed_und_token_indexes
].detach()
packed_query_states_ = packed_query_states.new_zeros(packed_query_states.shape)
packed_key_states_ = packed_key_states.new_zeros(packed_key_states.shape)
packed_query_states_[packed_und_token_indexes] = self.q_norm(
packed_query_states[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_query_states_[packed_und_token_indexes] = packed_query_states_[
packed_und_token_indexes
].detach()
packed_query_states_[packed_gen_token_indexes] = self.q_norm_moe_gen(
packed_query_states[packed_gen_token_indexes]
)
packed_key_states_[packed_und_token_indexes] = self.k_norm(
packed_key_states[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_key_states_[packed_und_token_indexes] = packed_key_states_[
packed_und_token_indexes
].detach()
packed_key_states_[packed_gen_token_indexes] = self.k_norm_moe_gen(
packed_key_states[packed_gen_token_indexes]
)
packed_cos, packed_sin = packed_position_embeddings
packed_query_states_, packed_key_states_ = apply_rotary_pos_emb(
packed_query_states_,
packed_key_states_,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
if isinstance(attention_mask, List):
packed_key_states_ = packed_key_states_[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_key_states_ = packed_key_states_.reshape(
-1, self.num_heads, self.head_dim
)
packed_value_states = packed_value_states[:, :, None, :].repeat(
1, 1, self.num_key_value_groups, 1
)
packed_value_states = packed_value_states.reshape(
-1, self.num_heads, self.head_dim
)
unpacked_query_states = packed_query_states_.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_key_states = packed_key_states_.transpose(0, 1).split(
sample_lens, dim=1
)
unpacked_value_states = packed_value_states.transpose(0, 1).split(
sample_lens, dim=1
)
upacked_attn_output = []
for (
query_states,
key_states,
value_states,
attention_mask_per_sample,
) in zip(
unpacked_query_states,
unpacked_key_states,
unpacked_value_states,
attention_mask,
):
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
attn_output = scaled_dot_product_attention(
query_states.to(torch.bfloat16).unsqueeze(0),
key_states.to(torch.bfloat16).unsqueeze(0),
value_states.to(torch.bfloat16).unsqueeze(0),
attention_mask_per_sample.to(torch.bfloat16).unsqueeze(0),
)
upacked_attn_output.append(attn_output.squeeze(0))
packed_attn_output = torch.cat(upacked_attn_output, dim=1)
else:
pad_size = sum(sample_lens) - packed_query_states.shape[0]
packed_query_states_ = pad_sequence(
packed_query_states_.permute(1, 0, 2), pad_size
)
packed_key_states_ = pad_sequence(
packed_key_states_.permute(1, 0, 2), pad_size
)
packed_value_states = pad_sequence(
packed_value_states.permute(1, 0, 2), pad_size
)
packed_attn_output = flex_attention(
packed_query_states_.unsqueeze(0), # 1, num_head, L, head_dim
packed_key_states_.unsqueeze(0),
packed_value_states.unsqueeze(0),
enable_gqa=True,
block_mask=attention_mask,
)
end_index = packed_attn_output.shape[2] - pad_size
packed_attn_output = packed_attn_output[0, :, :end_index, :]
packed_attn_output = packed_attn_output.transpose(0, 1).reshape(
-1, self.num_heads * self.head_dim
)
packed_attn_output_ = packed_attn_output.new_zeros(packed_attn_output.shape)
packed_attn_output_[packed_und_token_indexes] = self.o_proj(
packed_attn_output[packed_und_token_indexes]
)
packed_attn_output_[packed_gen_token_indexes] = self.o_proj_moe_gen(
packed_attn_output[packed_gen_token_indexes]
)
return packed_attn_output_
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
):
if mode == "und":
packed_query_states = self.q_proj(packed_query_sequence).view(
-1, self.num_heads, self.head_dim
)
packed_key_states = self.k_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = self.v_proj(packed_query_sequence).view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = self.q_norm(packed_query_states)
packed_key_states = self.k_norm(packed_key_states)
elif mode == "gen":
packed_query_sequence = packed_query_sequence.to(torch.bfloat16)
packed_query_states = packed_query_sequence.new_zeros(
(packed_query_sequence.shape[0], self.num_heads * self.head_dim)
)
packed_key_states = packed_query_sequence.new_zeros(
(
packed_query_sequence.shape[0],
self.num_key_value_heads * self.head_dim,
)
)
packed_value_states = packed_query_sequence.new_zeros(
(
packed_query_sequence.shape[0],
self.num_key_value_heads * self.head_dim,
)
)
packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
packed_query_states[packed_text_indexes] = self.q_proj(
packed_text_query_sequence
)
packed_query_states[packed_vae_token_indexes] = self.q_proj_moe_gen(
packed_vae_query_sequence
)
packed_key_states[packed_text_indexes] = self.k_proj(
packed_text_query_sequence
)
packed_key_states[packed_vae_token_indexes] = self.k_proj_moe_gen(
packed_vae_query_sequence
)
packed_value_states[packed_text_indexes] = self.v_proj(
packed_text_query_sequence
)
packed_value_states[packed_vae_token_indexes] = self.v_proj_moe_gen(
packed_vae_query_sequence
)
packed_query_states = packed_query_states.view(
-1, self.num_heads, self.head_dim
)
packed_key_states = packed_key_states.view(
-1, self.num_key_value_heads, self.head_dim
)
packed_value_states = packed_value_states.view(
-1, self.num_key_value_heads, self.head_dim
)
packed_query_states = packed_query_states.to(torch.float32)
packed_query_states[packed_text_indexes] = self.q_norm(
packed_query_states[packed_text_indexes]
)
packed_query_states[packed_vae_token_indexes] = self.q_norm_moe_gen(
packed_query_states[packed_vae_token_indexes]
)
packed_key_states = packed_key_states.to(torch.float32)
packed_key_states[packed_text_indexes] = self.k_norm(
packed_key_states[packed_text_indexes]
)
packed_key_states[packed_vae_token_indexes] = self.k_norm_moe_gen(
packed_key_states[packed_vae_token_indexes]
)
packed_cos, packed_sin = packed_query_position_embeddings
packed_query_states, packed_key_states = apply_rotary_pos_emb(
packed_query_states,
packed_key_states,
packed_cos,
packed_sin,
unsqueeze_dim=1,
)
packed_query_states = packed_query_states.to(torch.bfloat16)
packed_key_states = packed_key_states.to(torch.bfloat16)
packed_value_states = packed_value_states.to(torch.bfloat16)
if (
past_key_values is not None
and past_key_values.key_cache[self.layer_idx] is not None
):
past_key_states = past_key_values.key_cache[self.layer_idx]
past_value_states = past_key_values.value_cache[self.layer_idx]
seqlens = sum(query_lens) + sum(key_values_lens)
merged_key_states = past_key_states.new_zeros(
size=[seqlens, self.num_key_value_heads, self.head_dim]
)
merged_value_states = past_key_states.new_zeros(
size=[seqlens, self.num_key_value_heads, self.head_dim]
)
merged_key_states[packed_query_indexes] = packed_key_states
merged_key_states[packed_key_value_indexes] = past_key_states
merged_value_states[packed_query_indexes] = packed_value_states
merged_value_states[packed_key_value_indexes] = past_value_states
key_values_lens = key_values_lens + query_lens
else:
merged_key_states = packed_key_states
merged_value_states = packed_value_states
key_values_lens = query_lens
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
cu_seqlens_k = torch.nn.functional.pad(
torch.cumsum(key_values_lens, dim=0), (1, 0)
)
packed_attn_output = flash_attn_varlen_func(
q=packed_query_states,
k=merged_key_states,
v=merged_value_states,
cu_seqlens_q=cu_seqlens_q.to(torch.int32),
cu_seqlens_k=cu_seqlens_k.to(torch.int32),
max_seqlen_q=max(query_lens).item(),
max_seqlen_k=max(key_values_lens).item(),
causal=is_causal,
)
packed_attn_output = packed_attn_output.reshape(-1, self.hidden_size)
if mode == "und":
packed_attn_output = self.o_proj(packed_attn_output)
elif mode == "gen":
packed_attn_output[packed_text_indexes] = self.o_proj(
packed_attn_output[packed_text_indexes]
)
packed_attn_output[packed_vae_token_indexes] = self.o_proj_moe_gen(
packed_attn_output[packed_vae_token_indexes]
)
if update_past_key_values:
past_key_values.key_cache[self.layer_idx] = merged_key_states
past_key_values.value_cache[self.layer_idx] = merged_value_states
return packed_attn_output, past_key_values
class Qwen2DecoderLayer(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PackedAttention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
) -> torch.Tensor:
residual = packed_sequence
packed_sequence = self.input_layernorm(packed_sequence)
# Self Attention
packed_sequence = self.self_attn(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
)
packed_sequence = residual + packed_sequence
# Fully Connected
residual = packed_sequence
packed_sequence = self.post_attention_layernorm(packed_sequence)
packed_sequence = self.mlp(packed_sequence)
packed_sequence = residual + packed_sequence
return packed_sequence
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
) -> BaseNavitOutputWithPast:
residual = packed_query_sequence
packed_query_sequence = self.input_layernorm(packed_query_sequence)
# Self Attention
packed_query_sequence, past_key_values = self.self_attn(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
)
packed_query_sequence = residual + packed_query_sequence
# Fully Connected
residual = packed_query_sequence
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
packed_query_sequence = self.mlp(packed_query_sequence)
packed_query_sequence = residual + packed_query_sequence
return packed_query_sequence, past_key_values
class Qwen2MoTDecoderLayer(nn.Module):
def __init__(
self,
config,
layer_idx: Optional[int] = None,
attn_module: Optional[Qwen2Attention] = PackedAttentionMoT,
):
super().__init__()
self.hidden_size = config.hidden_size
self.freeze_und = config.freeze_und
self.self_attn = attn_module(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.mlp_moe_gen = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.input_layernorm_moe_gen = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_attention_layernorm_moe_gen = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
) -> torch.Tensor:
residual = packed_sequence
packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
packed_sequence_[packed_und_token_indexes] = self.input_layernorm(
packed_sequence[packed_und_token_indexes]
)
packed_sequence_[packed_gen_token_indexes] = self.input_layernorm_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
# Self Attention
packed_sequence_ = self.self_attn(
packed_sequence=packed_sequence_,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_gen_token_indexes,
)
if self.freeze_und:
packed_sequence_[packed_und_token_indexes] = packed_sequence_[
packed_und_token_indexes
].detach()
packed_sequence = residual + packed_sequence_
# Fully Connected
residual = packed_sequence
packed_sequence_ = packed_sequence.new_zeros(packed_sequence.shape)
packed_sequence_[packed_und_token_indexes] = self.mlp(
self.post_attention_layernorm(packed_sequence[packed_und_token_indexes])
)
if self.freeze_und:
packed_sequence_[packed_und_token_indexes] = packed_sequence_[
packed_und_token_indexes
].detach()
packed_sequence_[packed_gen_token_indexes] = self.mlp_moe_gen(
self.post_attention_layernorm_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
)
packed_sequence = residual + packed_sequence_
return packed_sequence
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
residual = packed_query_sequence
if mode == "und":
packed_query_sequence = self.input_layernorm(packed_query_sequence)
elif mode == "gen":
packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
packed_query_sequence_[packed_text_indexes] = self.input_layernorm(
packed_query_sequence[packed_text_indexes]
)
packed_query_sequence_[packed_vae_token_indexes] = (
self.input_layernorm_moe_gen(
packed_query_sequence[packed_vae_token_indexes]
)
)
packed_query_sequence = packed_query_sequence_
# Self Attention
packed_query_sequence, past_key_values = self.self_attn(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
mode=mode,
packed_vae_token_indexes=packed_vae_token_indexes,
packed_text_indexes=packed_text_indexes,
)
packed_query_sequence = residual + packed_query_sequence
# Fully Connected
residual = packed_query_sequence
if mode == "und":
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
packed_query_sequence = self.mlp(packed_query_sequence)
elif mode == "gen":
packed_text_query_sequence = packed_query_sequence[packed_text_indexes]
packed_vae_query_sequence = packed_query_sequence[packed_vae_token_indexes]
packed_text_query_sequence = self.post_attention_layernorm(
packed_text_query_sequence
).to(torch.bfloat16)
packed_vae_query_sequence = self.post_attention_layernorm_moe_gen(
packed_vae_query_sequence
).to(torch.bfloat16)
packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(
torch.bfloat16
)
packed_query_sequence_[packed_text_indexes] = self.mlp(
packed_text_query_sequence
)
packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(
packed_vae_query_sequence
)
packed_query_sequence = packed_query_sequence_
packed_query_sequence = residual + packed_query_sequence
return packed_query_sequence, past_key_values
class Qwen2MoEDecoderLayer(nn.Module):
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = PackedAttention(config, layer_idx)
self.mlp = Qwen2MLP(config)
self.mlp_moe_gen = Qwen2MLP(config)
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_embeddings: Tuple[torch.Tensor, torch.Tensor],
packed_und_token_indexes: torch.LongTensor,
packed_gen_token_indexes: torch.LongTensor,
) -> torch.Tensor:
residual = packed_sequence
packed_sequence = self.input_layernorm(packed_sequence)
# Self Attention
packed_sequence = self.self_attn(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
)
packed_sequence = residual + packed_sequence
# Fully Connected
residual = packed_sequence
packed_sequence = self.post_attention_layernorm(packed_sequence)
packed_sequence_new = packed_sequence.new_zeros(packed_sequence.shape)
packed_sequence_und = self.mlp(packed_sequence[packed_und_token_indexes])
packed_sequence_gen = self.mlp_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
packed_sequence_new[packed_und_token_indexes] = packed_sequence_und
packed_sequence_new[packed_gen_token_indexes] = packed_sequence_gen
packed_sequence = residual + packed_sequence_new
return packed_sequence
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_embeddings: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
residual = packed_query_sequence
packed_query_sequence = self.input_layernorm(packed_query_sequence)
# Self Attention
packed_query_sequence, past_key_values = self.self_attn(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
)
packed_query_sequence = residual + packed_query_sequence
# Fully Connected
residual = packed_query_sequence
packed_query_sequence = self.post_attention_layernorm(packed_query_sequence)
if mode == "und":
packed_query_sequence = self.mlp(packed_query_sequence)
elif mode == "gen":
packed_query_sequence_ = torch.zeros_like(packed_query_sequence).to(
torch.bfloat16
)
packed_query_sequence_[packed_text_indexes] = self.mlp(
packed_query_sequence[packed_text_indexes]
)
packed_query_sequence_[packed_vae_token_indexes] = self.mlp_moe_gen(
packed_query_sequence[packed_vae_token_indexes]
)
packed_query_sequence = packed_query_sequence_
packed_query_sequence = residual + packed_query_sequence
return packed_query_sequence, past_key_values
Decoder_layer_dict = {
"Qwen2DecoderLayer": Qwen2DecoderLayer,
"Qwen2MoEDecoderLayer": Qwen2MoEDecoderLayer,
"Qwen2MoTDecoderLayer": partial(
Qwen2MoTDecoderLayer, attn_module=PackedAttentionMoT
),
}
class Qwen2Model(Qwen2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.use_moe = "Mo" in config.layer_module
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
layer_module = Decoder_layer_dict[config.layer_module]
self.layers = nn.ModuleList(
[
layer_module(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if self.use_moe:
self.norm_moe_gen = Qwen2RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.rotary_emb = Qwen2RotaryEmbedding(config=config)
# Initialize weights and apply final processing
self.post_init()
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_ids: torch.Tensor,
packed_und_token_indexes: Optional[torch.LongTensor] = None,
packed_gen_token_indexes: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
if self.config.freeze_und:
packed_sequence[packed_und_token_indexes] = packed_sequence[
packed_und_token_indexes
].detach()
# create position embeddings to be shared across the decoder layers
cos, sin = self.rotary_emb(packed_sequence, packed_position_ids.unsqueeze(0))
cos = cos.squeeze(0)
sin = sin.squeeze(0)
packed_position_embeddings = (cos, sin)
extra_inputs = {}
if self.use_moe:
assert packed_und_token_indexes is not None
if packed_gen_token_indexes is None:
packed_gen_token_indexes = packed_und_token_indexes.new_ones(size=[0])
extra_inputs.update(
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_gen_token_indexes,
)
for decoder_layer in self.layers:
packed_sequence = decoder_layer(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
attention_mask=attention_mask,
packed_position_embeddings=packed_position_embeddings,
**extra_inputs,
)
if self.use_moe:
packed_sequence_ = torch.zeros_like(packed_sequence)
packed_sequence_[packed_und_token_indexes] = self.norm(
packed_sequence[packed_und_token_indexes]
)
if self.config.freeze_und:
packed_sequence_[packed_und_token_indexes] = packed_sequence_[
packed_und_token_indexes
].detach()
packed_sequence_[packed_gen_token_indexes] = self.norm_moe_gen(
packed_sequence[packed_gen_token_indexes]
)
return packed_sequence_
else:
return self.norm(packed_sequence)
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_ids: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
# create position embeddings to be shared across the decoder layers
cos, sin = self.rotary_emb(
packed_query_sequence, packed_query_position_ids.unsqueeze(0)
)
cos = cos.squeeze(0)
sin = sin.squeeze(0)
packed_query_position_embeddings = (cos, sin)
extra_inputs = {}
if self.use_moe:
extra_inputs.update(mode=mode)
if mode == "gen":
assert packed_vae_token_indexes is not None
assert packed_text_indexes is not None
extra_inputs.update(
packed_vae_token_indexes=packed_vae_token_indexes,
packed_text_indexes=packed_text_indexes,
)
for decoder_layer in self.layers:
packed_query_sequence, past_key_values = decoder_layer(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_embeddings=packed_query_position_embeddings,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
**extra_inputs,
)
if self.use_moe:
if mode == "und":
packed_query_sequence = self.norm(packed_query_sequence)
elif mode == "gen":
packed_query_sequence_ = torch.zeros_like(packed_query_sequence)
packed_query_sequence_[packed_text_indexes] = self.norm(
packed_query_sequence[packed_text_indexes]
)
packed_query_sequence_[packed_vae_token_indexes] = self.norm_moe_gen(
packed_query_sequence[packed_vae_token_indexes]
)
packed_query_sequence = packed_query_sequence_
else:
packed_query_sequence = self.norm(packed_query_sequence)
return BaseNavitOutputWithPast(
packed_query_sequence=packed_query_sequence,
past_key_values=past_key_values,
)
class Qwen2ForCausalLM(Qwen2PreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = Qwen2Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def init_moe(self):
for name, param in self.named_parameters():
if "moe_gen" in name:
original_name = name.replace("_moe_gen", "")
param.data.copy_(self.state_dict()[original_name].data)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(self, *args, **kwargs):
if self.training:
return self.forward_train(*args, **kwargs)
else:
return self.forward_inference(*args, **kwargs)
def forward_train(
self,
packed_sequence: torch.Tensor,
sample_lens: List[int],
attention_mask,
packed_position_ids: torch.Tensor,
packed_und_token_indexes: Optional[torch.LongTensor] = None,
packed_gen_token_indexes: Optional[torch.LongTensor] = None,
) -> torch.Tensor:
outputs = self.model(
packed_sequence=packed_sequence,
sample_lens=sample_lens,
packed_position_ids=packed_position_ids,
attention_mask=attention_mask,
packed_und_token_indexes=packed_und_token_indexes,
packed_gen_token_indexes=packed_gen_token_indexes,
)
return outputs
def forward_inference(
self,
packed_query_sequence: torch.Tensor,
query_lens: torch.Tensor,
packed_query_position_ids: torch.Tensor,
packed_query_indexes: torch.Tensor,
past_key_values: Optional[NaiveCache] = None,
key_values_lens: Optional[torch.Tensor] = None,
packed_key_value_indexes: Optional[torch.Tensor] = None,
update_past_key_values=True,
is_causal=True,
mode="und",
packed_vae_token_indexes=None,
packed_text_indexes=None,
) -> BaseNavitOutputWithPast:
outputs = self.model(
packed_query_sequence=packed_query_sequence,
query_lens=query_lens,
packed_query_position_ids=packed_query_position_ids,
packed_query_indexes=packed_query_indexes,
past_key_values=past_key_values,
key_values_lens=key_values_lens,
packed_key_value_indexes=packed_key_value_indexes,
update_past_key_values=update_past_key_values,
is_causal=is_causal,
mode=mode,
packed_vae_token_indexes=packed_vae_token_indexes,
packed_text_indexes=packed_text_indexes,
)
return outputs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment