Commit 3d735feb authored by luopl: "Initial commit"
[
{
"image": "demo/images/10095.png",
"conversations": [
{
"from": "human",
"value": "Is the value of Favorable 38 in 2015?\n<image>"
},
{
"from": "gpt",
"value": "Yes"
}
]
},
{
"image": "demo/images/10095.png",
"conversations": [
{
"from": "human",
"value": "How many values are below 40 in Unfavorable graph?\n<image>"
},
{
"from": "gpt",
"value": "6"
}
]
},
{
"image": "demo/images/10149.png",
"conversations": [
{
"from": "human",
"value": "In which year the value was 51?\n<image>"
},
{
"from": "gpt",
"value": "2014"
}
]
}
]
[
{
"video": "v_7bUu05RIksU.mp4",
"conversations": [
{
"from": "human",
"value": "<video>\nCan you give me an overview of the video content?"
},
{
"from": "gpt",
"value": "The video showcases a group of men washing cars in a parking lot using various tools and techniques, such as pressure washing hoses, hand towels, and brushes. There are also text displays of special offers and business details throughout the video. The men work on different cars, including a white car and a black SUV."
}
]
},
{
"video": "v_TpB_zMG3XBA.mp4",
"conversations": [
{
"from": "human",
"value": "<video>\nCan you describe what the woman is doing on the lake?"
},
{
"from": "gpt",
"value": "The woman is pulled forward in the lake and then begins to glide on the water using a surfboard. She sways left and right while sliding, and then slowly squats down, leaving the surfboard and falling into the water and floating on the surface."
}
]
},
{
"video": "v_rBMQFpHspmo.mp4",
"conversations": [
{
"from": "human",
"value": "<video>\nWhat is the main activity happening in the video?"
},
{
"from": "gpt",
"value": "The main activity in the video is a boy attempting to pole vault over a bar in an indoor field while being watched by several people."
}
]
}
]
import re
# Define placeholders for dataset paths
CAMBRIAN_737K = {
"annotation_path": "PATH_TO_CAMBRIAN_737K_ANNOTATION",
"data_path": "",
}
CAMBRIAN_737K_PACK = {
"annotation_path": f"PATH_TO_CAMBRIAN_737K_ANNOTATION_PACKED",
"data_path": f"",
}
MP_DOC = {
"annotation_path": "PATH_TO_MP_DOC_ANNOTATION",
"data_path": "PATH_TO_MP_DOC_DATA",
}
CLEVR_MC = {
"annotation_path": "PATH_TO_CLEVR_MC_ANNOTATION",
"data_path": "PATH_TO_CLEVR_MC_DATA",
}
VIDEOCHATGPT = {
"annotation_path": "PATH_TO_VIDEOCHATGPT_ANNOTATION",
"data_path": "PATH_TO_VIDEOCHATGPT_DATA",
}
data_dict = {
"cambrian_737k": CAMBRIAN_737K,
"cambrian_737k_pack": CAMBRIAN_737K_PACK,
"mp_doc": MP_DOC,
"clevr_mc": CLEVR_MC,
"videochatgpt": VIDEOCHATGPT,
}
def parse_sampling_rate(dataset_name):
match = re.search(r"%(\d+)$", dataset_name)
if match:
return int(match.group(1)) / 100.0
return 1.0
def data_list(dataset_names):
config_list = []
for dataset_name in dataset_names:
sampling_rate = parse_sampling_rate(dataset_name)
dataset_name = re.sub(r"%(\d+)$", "", dataset_name)
        if dataset_name in data_dict:
            config = data_dict[dataset_name].copy()
            config["sampling_rate"] = sampling_rate
            config_list.append(config)
        else:
            raise ValueError(f"Unknown dataset: {dataset_name}")
return config_list
if __name__ == "__main__":
dataset_names = ["cambrian_737k"]
configs = data_list(dataset_names)
for config in configs:
print(config)
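    # Sketch: a "%NN" suffix requests NN percent random sampling of that dataset,
    # e.g. "cambrian_737k%25" returns the same config with sampling_rate == 0.25.
    for config in data_list(["cambrian_737k%25"]):
        print(config)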
import json
import random
import logging
import re
import time
import itertools
from dataclasses import dataclass, field
from typing import Dict, Optional, List, Tuple, Any
# collections.abc.Sequence (rather than typing.Sequence) is imported for the isinstance checks below.
from collections.abc import Sequence
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset
import transformers
from . import data_list
from .rope2d import get_rope_index_25, get_rope_index_2, get_rope_index_3
IGNORE_INDEX = -100
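# Special-token ids from the Qwen2/2.5/3-VL tokenizer: 151655 is <|image_pad|>, 151656 is <|video_pad|>.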
IMAGE_TOKEN_INDEX = 151655
VIDEO_TOKEN_INDEX = 151656
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_VIDEO_TOKEN = "<video>"
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def read_jsonl(path):
with open(path, "r") as f:
return [json.loads(line) for line in f]
def _make_abs_paths(base: Path, files: str) -> str:
return f"{(base / files).resolve()}"
def update_processor_pixels(processor, data_args):
logger = logging.getLogger(__name__)
# --- Image Processor ---
ip = processor.image_processor
rank0_print("=== BEFORE IMAGE PROCESSOR PARAMETERS ===")
rank0_print(f"Image min_pixels: {getattr(ip, 'min_pixels', 'N/A')}")
rank0_print(f"Image max_pixels: {getattr(ip, 'max_pixels', 'N/A')}")
rank0_print(f"ip.size: {ip.size}")
rank0_print(f"Image size (shortest_edge): {ip.size.get('shortest_edge', 'N/A')}")
rank0_print(f"Image size (longest_edge): {ip.size.get('longest_edge', 'N/A')}")
if hasattr(ip, "min_pixels") and hasattr(ip, "max_pixels"):
ip.min_pixels = data_args.min_pixels
ip.max_pixels = data_args.max_pixels
rank0_print(f"✅ Updated image_processor min_pixels to {data_args.min_pixels}")
rank0_print(f"✅ Updated image_processor max_pixels to {data_args.max_pixels}")
if hasattr(ip, "size") and isinstance(ip.size, dict):
ip.size["shortest_edge"] = data_args.min_pixels
ip.size["longest_edge"] = data_args.max_pixels
rank0_print(
f"✅ Updated image_processor size['shortest_edge'] to {data_args.min_pixels}"
)
rank0_print(
f"✅ Updated image_processor size['longest_edge'] to {data_args.max_pixels}"
)
rank0_print("=== AFTER IMAGE PROCESSOR PARAMETERS ===")
rank0_print(f"Image min_pixels: {getattr(ip, 'min_pixels', 'N/A')}")
rank0_print(f"Image max_pixels: {getattr(ip, 'max_pixels', 'N/A')}")
rank0_print(f"Image size (shortest_edge): {ip.size.get('shortest_edge', 'N/A')}")
rank0_print(f"Image size (longest_edge): {ip.size.get('longest_edge', 'N/A')}")
# --- Video Processor ---
if hasattr(processor, "video_processor") and processor.video_processor is not None:
vp = processor.video_processor
rank0_print("\n=== BEFORE VIDEO PROCESSOR PARAMETERS ===")
rank0_print(f"Video min_pixels: {getattr(vp, 'min_pixels', 'N/A')}")
rank0_print(f"Video max_pixels: {getattr(vp, 'max_pixels', 'N/A')}")
rank0_print(f"Video min_frames: {getattr(vp, 'min_frames', 'N/A')}")
rank0_print(f"Video max_frames: {getattr(vp, 'max_frames', 'N/A')}")
rank0_print(f"Video fps: {getattr(vp, 'fps', 'N/A')}")
rank0_print(
f"Video size (shortest_edge): {vp.size.get('shortest_edge', 'N/A')}"
)
rank0_print(f"Video size (longest_edge): {vp.size.get('longest_edge', 'N/A')}")
if hasattr(vp, "min_pixels") and hasattr(vp, "max_pixels"):
vp.min_pixels = data_args.video_min_pixels
vp.max_pixels = data_args.video_max_pixels
rank0_print(
f"✅ Updated Qwen2-VL video_processor min_pixels to {data_args.video_min_pixels}"
)
rank0_print(
f"✅ Updated Qwen2-VL video_processor max_pixels to {data_args.video_max_pixels}"
)
if hasattr(vp, "min_frames") and hasattr(vp, "max_frames"):
vp.min_frames = data_args.video_min_frames
vp.max_frames = data_args.video_max_frames
rank0_print(
f"✅ Updated video_processor min_frames to {data_args.video_min_frames}"
)
rank0_print(
f"✅ Updated video_processor max_frames to {data_args.video_max_frames}"
)
if hasattr(vp, "fps"):
vp.fps = data_args.video_fps
rank0_print(f"✅ Updated video_processor fps to {data_args.video_fps}")
if hasattr(vp, "size") and isinstance(vp.size, dict):
vp.size["shortest_edge"] = data_args.video_min_pixels
vp.size["longest_edge"] = data_args.video_max_pixels
rank0_print(
f"✅ Updated Video size (shortest_edge): {vp.size.get('shortest_edge', 'N/A')}"
)
rank0_print(
f"✅ Updated Video size (longest_edge): {vp.size.get('longest_edge', 'N/A')}"
)
rank0_print("=== AFTER VIDEO PROCESSOR PARAMETERS ===")
rank0_print(f"Video min_pixels: {getattr(vp, 'min_pixels', 'N/A')}")
rank0_print(f"Video max_pixels: {getattr(vp, 'max_pixels', 'N/A')}")
rank0_print(f"Video min_frames: {getattr(vp, 'min_frames', 'N/A')}")
rank0_print(f"Video max_frames: {getattr(vp, 'max_frames', 'N/A')}")
rank0_print(f"Video fps: {getattr(vp, 'fps', 'N/A')}")
rank0_print(
f"Video size (shortest_edge): {vp.size.get('shortest_edge', 'N/A')}"
)
rank0_print(f"Video size (longest_edge): {vp.size.get('longest_edge', 'N/A')}")
return processor
def _build_messages(item: Dict[str, Any], base_path: Path) -> List[Dict[str, Any]]:
# Extract and normalize images and videos
images = item.get("image") or []
if isinstance(images, str):
images = [images]
videos = item.get("video") or []
if isinstance(videos, str):
videos = [videos]
# Build media pools with absolute paths
image_pool = [
{"type": "image", "image": _make_abs_paths(base_path, img)} for img in images
]
video_pool = [
{"type": "video", "video": _make_abs_paths(base_path, vid)} for vid in videos
]
messages = []
for turn in item["conversations"]:
role = "user" if turn["from"] == "human" else "assistant"
text: str = turn["value"]
if role == "user":
content = []
# Split text by <image> or <video> placeholders while keeping delimiters
text_parts = re.split(r"(<image>|<video>)", text)
for seg in text_parts:
if seg == "<image>":
if not image_pool:
raise ValueError(
"Number of <image> placeholders exceeds the number of provided images"
)
content.append(image_pool.pop(0))
elif seg == "<video>":
if not video_pool:
raise ValueError(
"Number of <video> placeholders exceeds the number of provided videos"
)
content.append(video_pool.pop(0))
elif seg.strip():
content.append({"type": "text", "text": seg.strip()})
messages.append({"role": role, "content": content})
else:
# Assistant messages contain only text
messages.append({"role": role, "content": [{"type": "text", "text": text}]})
# Check for unused media files
if image_pool:
raise ValueError(
f"{len(image_pool)} image(s) remain unused (not consumed by placeholders)"
)
if video_pool:
raise ValueError(
f"{len(video_pool)} video(s) remain unused (not consumed by placeholders)"
)
return messages
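# Illustrative example: an annotation item such as
#   {"image": "demo/images/10095.png",
#    "conversations": [{"from": "human", "value": "Is the value of Favorable 38 in 2015?\n<image>"},
#                      {"from": "gpt", "value": "Yes"}]}
# becomes chat messages of the form
#   [{"role": "user", "content": [{"type": "text", "text": "Is the value of Favorable 38 in 2015?"},
#                                 {"type": "image", "image": "<abs path to>/demo/images/10095.png"}]},
#    {"role": "assistant", "content": [{"type": "text", "text": "Yes"}]}]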
def preprocess_qwen_visual(
sources,
processor,
) -> Dict:
if len(sources) != 1:
raise ValueError(f"Expected 1 source, got {len(sources)}")
source = sources[0]
base_path = Path(source.get("data_path", ""))
messages = _build_messages(source, base_path)
full_result = processor.apply_chat_template(
messages, tokenize=True, return_dict=True, return_tensors="pt"
)
input_ids = full_result["input_ids"]
if isinstance(input_ids, list):
input_ids = torch.tensor(input_ids).unsqueeze(0)
labels = torch.full_like(input_ids, IGNORE_INDEX)
input_ids_flat = input_ids[0].tolist()
L = len(input_ids_flat)
pos = 0
while pos < L:
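        # 77091 is the "assistant" role token and 151645 is <|im_end|> in the Qwen chat template.
        # Only assistant response spans (including the closing <|im_end|> and the token after it)
        # are copied into labels; everything else stays at IGNORE_INDEX.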
if input_ids_flat[pos] == 77091:
ans_start = pos + 2
ans_end = ans_start
while ans_end < L and input_ids_flat[ans_end] != 151645:
ans_end += 1
if ans_end < L:
labels[0, ans_start : ans_end + 2] = input_ids[
0, ans_start : ans_end + 2
]
pos = ans_end
pos += 1
full_result["labels"] = labels
full_result["input_ids"] = input_ids
return full_result
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, processor, data_args):
super(LazySupervisedDataset, self).__init__()
dataset = data_args.dataset_use.split(",")
dataset_list = data_list(dataset)
rank0_print(f"Loading datasets: {dataset_list}")
self.video_max_total_pixels = getattr(
data_args, "video_max_total_pixels", 1664 * 28 * 28
)
self.video_min_total_pixels = getattr(
data_args, "video_min_total_pixels", 256 * 28 * 28
)
self.model_type = data_args.model_type
if data_args.model_type == "qwen3vl":
self.get_rope_index = get_rope_index_3
elif data_args.model_type == "qwen2.5vl":
self.get_rope_index = get_rope_index_25
elif data_args.model_type == "qwen2vl":
self.get_rope_index = get_rope_index_2
else:
raise ValueError(f"model_type: {data_args.model_type} not supported")
list_data_dict = []
for data in dataset_list:
file_format = data["annotation_path"].split(".")[-1]
if file_format == "jsonl":
annotations = read_jsonl(data["annotation_path"])
else:
annotations = json.load(open(data["annotation_path"], "r"))
sampling_rate = data.get("sampling_rate", 1.0)
if sampling_rate < 1.0:
annotations = random.sample(
annotations, int(len(annotations) * sampling_rate)
)
rank0_print(f"sampling {len(annotations)} examples from dataset {data}")
else:
rank0_print(f"dataset name: {data}")
for ann in annotations:
if isinstance(ann, list):
for sub_ann in ann:
sub_ann["data_path"] = data["data_path"]
else:
ann["data_path"] = data["data_path"]
list_data_dict += annotations
rank0_print(f"Total training samples: {len(list_data_dict)}")
random.shuffle(list_data_dict) # Randomly shuffle the data for training
rank0_print("Formatting inputs...Skip in lazy mode")
processor = update_processor_pixels(processor, data_args)
self.processor = processor
self.tokenizer = processor.tokenizer
self.data_args = data_args
self.merge_size = getattr(processor.image_processor, "merge_size", 2)
self.list_data_dict = list_data_dict
if data_args.data_packing:
self.item_fn = self._get_packed_item
else:
self.item_fn = self._get_item
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if "image" in sample else 0
length_list.append(
sum(len(conv["value"].split()) for conv in sample["conversations"])
+ img_tokens
)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(
len(conv["value"].split()) for conv in sample["conversations"]
)
cur_len = (
cur_len if ("image" in sample) or ("video" in sample) else -cur_len
)
length_list.append(cur_len)
return length_list
@property
def pre_calculated_length(self):
if "num_tokens" in self.list_data_dict[0]:
length_list = [sample["num_tokens"] for sample in self.list_data_dict]
return np.array(length_list)
else:
print("No pre-calculated length available.")
return np.array([1] * len(self.list_data_dict))
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
num_base_retries = 3
num_final_retries = 30
# try the current sample first
for attempt_idx in range(num_base_retries):
try:
sources = self.list_data_dict[i]
if isinstance(sources, dict):
sources = [sources]
sample = self.item_fn(sources)
return sample
except Exception as e:
# sleep 1s in case it is a cloud disk issue
print(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception:", e)
time.sleep(1)
# try other samples, in case it is file corruption issue
for attempt_idx in range(num_base_retries):
try:
next_index = min(i + 1, len(self.list_data_dict) - 1)
sources = self.list_data_dict[next_index]
if isinstance(sources, dict):
sources = [sources]
sample = self.item_fn(sources)
return sample
except Exception as e:
# no need to sleep
print(
f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception:",
e,
)
pass
try:
sources = self.list_data_dict[i]
if isinstance(sources, dict):
sources = [sources]
sample = self.item_fn(sources)
return sample
except Exception as e:
raise e
def _get_item(self, sources) -> Dict[str, torch.Tensor]:
data_dict = preprocess_qwen_visual(
sources,
self.processor,
)
seq_len = data_dict["input_ids"][0].size(0)
if "image_grid_thw" in data_dict:
grid_thw = data_dict.get("image_grid_thw")
if not isinstance(grid_thw, Sequence):
grid_thw = [grid_thw]
else:
grid_thw = None
if "video_grid_thw" in data_dict:
video_grid_thw = data_dict.get("video_grid_thw")
if not isinstance(video_grid_thw, Sequence):
video_grid_thw = [video_grid_thw]
second_per_grid_ts = [
self.processor.video_processor.temporal_patch_size
/ self.processor.video_processor.fps
] * len(video_grid_thw)
else:
video_grid_thw = None
second_per_grid_ts = None
position_ids, _ = self.get_rope_index(
self.merge_size,
data_dict["input_ids"],
image_grid_thw=torch.cat(grid_thw, dim=0) if grid_thw else None,
video_grid_thw=(
torch.cat(video_grid_thw, dim=0) if video_grid_thw else None
),
second_per_grid_ts=second_per_grid_ts if second_per_grid_ts else None,
)
data_dict["position_ids"] = position_ids
data_dict["attention_mask"] = [seq_len]
text = self.processor.tokenizer.decode(
data_dict["input_ids"][0], skip_special_tokens=False
)
labels = data_dict["labels"][0]
labels = [
tid if tid != -100 else self.processor.tokenizer.pad_token_id
for tid in labels
]
label = self.processor.tokenizer.decode(labels, skip_special_tokens=False)
return data_dict
    def _get_packed_item(self, sources) -> Dict[str, torch.Tensor]:
        # A single un-packed sample: wrap it in a list and fall back to the standard item path.
        if isinstance(sources, dict):
            return self._get_item([sources])
if isinstance(sources, list):
data_list = []
new_data_dict = {}
for source in sources:
if isinstance(source, dict):
source = [source]
assert (
len(source) == 1
), f"Don't know why it is wrapped to a list.\n {source}" # FIXME
data_list.append(self._get_item(source))
input_ids = torch.cat([d["input_ids"] for d in data_list], dim=1)
labels = torch.cat([d["labels"] for d in data_list], dim=1)
position_ids = torch.cat([d["position_ids"] for d in data_list], dim=2)
attention_mask = [
d["attention_mask"][0] for d in data_list if "attention_mask" in d
]
new_data_dict = {
"input_ids": input_ids,
"labels": labels,
"position_ids": position_ids,
"attention_mask": attention_mask if attention_mask else None,
}
if any("pixel_values" in d for d in data_list):
new_data_dict.update(
{
"pixel_values": torch.cat(
[
d["pixel_values"]
for d in data_list
if "pixel_values" in d
],
dim=0,
),
"image_grid_thw": torch.cat(
[
d["image_grid_thw"]
for d in data_list
if "image_grid_thw" in d
],
dim=0,
),
}
)
if any("pixel_values_videos" in d for d in data_list):
new_data_dict.update(
{
"pixel_values_videos": torch.cat(
[
d["pixel_values_videos"]
for d in data_list
if "pixel_values_videos" in d
],
dim=0,
),
"video_grid_thw": torch.cat(
[
d["video_grid_thw"]
for d in data_list
if "video_grid_thw" in d
],
dim=0,
),
}
)
return new_data_dict
def pad_and_cat(tensor_list):
max_length = max(tensor.shape[2] for tensor in tensor_list)
padded_tensors = []
for tensor in tensor_list:
pad_length = max_length - tensor.shape[2]
padded_tensor = torch.nn.functional.pad(tensor, (0, pad_length), "constant", 1)
padded_tensors.append(padded_tensor)
stacked_tensor = torch.cat(padded_tensors, dim=1)
return stacked_tensor
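# Illustrative shapes: position_ids tensors of shape (3, 1, 10) and (3, 1, 7) are right-padded along
# the sequence dimension (pad value 1) and concatenated on the batch dimension into a (3, 2, 10) tensor.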
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels, position_ids = tuple(
[instance[key] for instance in instances]
for key in ("input_ids", "labels", "position_ids")
)
input_ids = [ids.squeeze(0) for ids in input_ids]
labels = [ids.squeeze(0) for ids in labels]
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(
labels, batch_first=True, padding_value=IGNORE_INDEX
)
position_ids = pad_and_cat(position_ids)
input_ids = input_ids[:, : self.tokenizer.model_max_length]
labels = labels[:, : self.tokenizer.model_max_length]
position_ids = position_ids[:, :, : self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
images = list(
instance["pixel_values"]
for instance in instances
if "pixel_values" in instance
)
videos = list(
instance["pixel_values_videos"]
for instance in instances
if "pixel_values_videos" in instance
)
if len(images) != 0:
concat_images = torch.cat([image for image in images], dim=0)
grid_thw = [
instance["image_grid_thw"]
for instance in instances
if "image_grid_thw" in instance
]
grid_thw = torch.cat(grid_thw, dim=0)
else:
concat_images = None
grid_thw = None
if len(videos) != 0:
concat_videos = torch.cat([video for video in videos], dim=0)
video_grid_thw = [
instance["video_grid_thw"]
for instance in instances
if "video_grid_thw" in instance
]
video_grid_thw = torch.cat(video_grid_thw, dim=0)
else:
concat_videos = None
video_grid_thw = None
batch["pixel_values"] = concat_images
batch["image_grid_thw"] = grid_thw
batch["pixel_values_videos"] = concat_videos
batch["video_grid_thw"] = video_grid_thw
batch["position_ids"] = position_ids
return batch
@dataclass
class FlattenedDataCollatorForSupervisedDataset(DataCollatorForSupervisedDataset):
"""Collate examples into packed sequence with multi-modal support."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels, position_ids, attention_mask = tuple(
[instance[key] for instance in instances]
for key in ("input_ids", "labels", "position_ids", "attention_mask")
)
attention_mask = list(
itertools.chain(
*(
instance["attention_mask"]
for instance in instances
if "attention_mask" in instance
)
)
)
seq_lens = torch.tensor([0] + attention_mask, dtype=torch.int32)
cumsum_seq_lens = torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
input_ids = torch.cat(input_ids, dim=1)
labels = torch.cat(labels, dim=1)
position_ids = torch.cat(position_ids, dim=2)
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=cumsum_seq_lens,
position_ids=position_ids,
)
images = list(
instance["pixel_values"]
for instance in instances
if "pixel_values" in instance
)
videos = list(
instance["pixel_values_videos"]
for instance in instances
if "pixel_values_videos" in instance
)
if len(images) != 0:
concat_images = torch.cat([image for image in images], dim=0)
grid_thw = [
instance["image_grid_thw"]
for instance in instances
if "image_grid_thw" in instance
]
grid_thw = torch.cat(grid_thw, dim=0)
else:
concat_images = None
grid_thw = None
if len(videos) != 0:
concat_videos = torch.cat([video for video in videos], dim=0)
video_grid_thw = [
instance["video_grid_thw"]
for instance in instances
if "video_grid_thw" in instance
]
video_grid_thw = torch.cat(video_grid_thw, dim=0)
else:
concat_videos = None
video_grid_thw = None
batch["pixel_values"] = concat_images
batch["image_grid_thw"] = grid_thw
batch["pixel_values_videos"] = concat_videos
batch["video_grid_thw"] = video_grid_thw
return batch
def make_supervised_data_module(processor, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(processor, data_args=data_args)
if data_args.data_flatten or data_args.data_packing:
data_collator = FlattenedDataCollatorForSupervisedDataset(processor.tokenizer)
return dict(
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
)
data_collator = DataCollatorForSupervisedDataset(processor.tokenizer)
return dict(
train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
)
if __name__ == "__main__":
pass
import torch
from typing import Dict, Optional, Sequence, List, Tuple
def get_rope_index_3(
spatial_merge_size: Optional[int] = 2,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Different from the original implementation, Qwen3VL use timestamps rather than absolute time position ids."""
# Since we use timestamps to seperate videos, like <t1> <vision_start> <frame1> <vision_end> <t2> <vision_start> <frame2> <vision_end>, the video_grid_thw should also be split
if video_grid_thw is not None:
video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0)
video_grid_thw[:, 0] = 1
image_token_id = 151655
video_token_id = 151656
vision_start_token_id = 151652
mrope_position_deltas = []
if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
# t_index is always 0 because llm_grid_t is always 1 (we use timestamps to encode the temporal information for videos)
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
text_len = len(input_tokens) - st
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def get_rope_index_25(
spatial_merge_size: Optional[int] = 2,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
and 1D rotary position embedding for text part.
Examples:
Temporal (Time): 3 patches, representing different segments of the video in time.
Height: 2 patches, dividing each frame vertically.
Width: 2 patches, dividing each frame horizontally.
We also have some important parameters:
fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
        interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [101, 102, 103, 104, 105]
text height position_ids: [101, 102, 103, 104, 105]
text width position_ids: [101, 102, 103, 104, 105]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
image_token_id = 151655
video_token_id = 151656
vision_start_token_id = 151652
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
attention_mask = attention_mask.to(total_input_ids.device)
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
second_per_grid_t = 0
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
if second_per_grid_ts is not None:
second_per_grid_t = second_per_grid_ts[video_index]
else:
second_per_grid_t = 1.0
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
time_tensor = expanded_range * second_per_grid_t * 2
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
position_ids.device
)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = (
position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
def get_rope_index_2(
spatial_merge_size: Optional[int] = 2,
input_ids: Optional[torch.LongTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
video_grid_thw: Optional[torch.LongTensor] = None,
second_per_grid_ts: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
Explanation:
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
        For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
Examples:
input_ids: [T T T T T], here T is for text.
temporal position_ids: [0, 1, 2, 3, 4]
height position_ids: [0, 1, 2, 3, 4]
width position_ids: [0, 1, 2, 3, 4]
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
        and 1D rotary position embedding for text part.
Examples:
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
text temporal position_ids: [3, 4, 5, 6, 7]
text height position_ids: [3, 4, 5, 6, 7]
text width position_ids: [3, 4, 5, 6, 7]
Here we calculate the text start position_ids as the max vision position_ids plus 1.
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
Returns:
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
"""
image_token_id = 151655
video_token_id = 151656
vision_start_token_id = 151652
mrope_position_deltas = []
if input_ids is not None and (
image_grid_thw is not None or video_grid_thw is not None
):
total_input_ids = input_ids
if attention_mask is None:
attention_mask = torch.ones_like(total_input_ids)
position_ids = torch.ones(
3,
input_ids.shape[0],
input_ids.shape[1],
dtype=input_ids.dtype,
device=input_ids.device,
)
image_index, video_index = 0, 0
for i, input_ids in enumerate(total_input_ids):
input_ids = input_ids[attention_mask[i] == 1]
image_nums, video_nums = 0, 0
vision_start_indices = torch.argwhere(
input_ids == vision_start_token_id
).squeeze(1)
vision_tokens = input_ids[vision_start_indices + 1]
image_nums = (vision_tokens == image_token_id).sum()
video_nums = (vision_tokens == video_token_id).sum()
input_tokens = input_ids.tolist()
llm_pos_ids_list: list = []
st = 0
remain_images, remain_videos = image_nums, video_nums
for _ in range(image_nums + video_nums):
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
ed_image = len(input_tokens) + 1
if video_token_id in input_tokens and remain_videos > 0:
ed_video = input_tokens.index(video_token_id, st)
else:
ed_video = len(input_tokens) + 1
if ed_image < ed_video:
t, h, w = (
image_grid_thw[image_index][0],
image_grid_thw[image_index][1],
image_grid_thw[image_index][2],
)
image_index += 1
remain_images -= 1
ed = ed_image
else:
t, h, w = (
video_grid_thw[video_index][0],
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
video_index += 1
remain_videos -= 1
ed = ed_video
llm_grid_t, llm_grid_h, llm_grid_w = (
t.item(),
h.item() // spatial_merge_size,
w.item() // spatial_merge_size,
)
text_len = ed - st
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
.expand(-1, llm_grid_h * llm_grid_w)
.flatten()
)
h_index = (
torch.arange(llm_grid_h)
.view(1, -1, 1)
.expand(llm_grid_t, -1, llm_grid_w)
.flatten()
)
w_index = (
torch.arange(llm_grid_w)
.view(1, 1, -1)
.expand(llm_grid_t, llm_grid_h, -1)
.flatten()
)
llm_pos_ids_list.append(
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
)
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
if st < len(input_tokens):
st_idx = (
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
)
text_len = len(input_tokens) - st
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
)
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
position_ids.device
)
mrope_position_deltas.append(
llm_positions.max() + 1 - len(total_input_ids[i])
)
mrope_position_deltas = torch.tensor(
mrope_position_deltas, device=input_ids.device
).unsqueeze(1)
return position_ids, mrope_position_deltas
else:
if attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids = (
position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
)
max_position_ids = position_ids.max(0, keepdim=False)[0].max(
-1, keepdim=True
)[0]
mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
else:
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device)
.view(1, 1, -1)
.expand(3, input_ids.shape[0], -1)
)
mrope_position_deltas = torch.zeros(
[input_ids.shape[0], 1],
device=input_ids.device,
dtype=input_ids.dtype,
)
return position_ids, mrope_position_deltas
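# Minimal sketch (assumes a text-only batch with no image/video grids): the 3D M-RoPE indices
# collapse to ordinary 1D positions broadcast across the temporal/height/width axes.
if __name__ == "__main__":
    ids = torch.arange(6).unsqueeze(0)  # (batch=1, seq_len=6)
    pos, deltas = get_rope_index_2(2, ids, attention_mask=torch.ones_like(ids))
    print(pos.shape, deltas)  # torch.Size([3, 1, 6]) tensor([[0]])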
import transformers
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, List
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="Qwen/Qwen2.5-VL-3B-Instruct")
tune_mm_llm: bool = field(default=False)
tune_mm_mlp: bool = field(default=False)
tune_mm_vision: bool = field(default=False)
@dataclass
class DataArguments:
dataset_use: str = field(default="")
data_flatten: bool = field(default=False)
data_packing: bool = field(default=False)
base_interval: int = field(default=2)
max_pixels: int = field(default=28 * 28 * 576)
min_pixels: int = field(default=28 * 28 * 16)
video_max_frames: Optional[int] = field(default=8)
video_min_frames: Optional[int] = field(default=4)
video_max_pixels: int = field(default=1024 * 28 * 28)
video_min_pixels: int = field(default=256 * 28 * 28)
video_fps: float = 2
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
mm_projector_lr: Optional[float] = None
vision_tower_lr: Optional[float] = None
## Lora config
lora_enable: bool = field(default=False)
lora_r: int = field(default=64)
lora_alpha: int = field(default=128)
lora_dropout: float = field(default=0.0)
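# Usage sketch: these dataclasses are consumed through HfArgumentParser in the training entry point:
#   parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
#   model_args, data_args, training_args = parser.parse_args_into_dataclasses()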
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import pathlib
import torch
import transformers
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.parent
sys.path.append(str(project_root))
from trainer import replace_qwen2_vl_attention_class
from transformers import (
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Qwen3VLForConditionalGeneration,
Qwen3VLMoeForConditionalGeneration
)
from qwenvl.data.data_processor import make_supervised_data_module
from qwenvl.train.argument import (
ModelArguments,
DataArguments,
TrainingArguments,
)
from transformers import AutoProcessor, Trainer
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
if trainer.deepspeed:
torch.cuda.synchronize()
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def set_model(model_args, model):
if model_args.tune_mm_vision:
for n, p in model.visual.named_parameters():
p.requires_grad = True
else:
for n, p in model.visual.named_parameters():
p.requires_grad = False
if model_args.tune_mm_mlp:
for n, p in model.visual.merger.named_parameters():
p.requires_grad = True
else:
for n, p in model.visual.merger.named_parameters():
p.requires_grad = False
    if model_args.tune_mm_llm:
        for n, p in model.language_model.named_parameters():
            p.requires_grad = True
        # requires_grad has to be set on the lm_head parameters themselves; setting it on the
        # module object would be a silent no-op.
        for p in model.lm_head.parameters():
            p.requires_grad = True
    else:
        for n, p in model.language_model.named_parameters():
            p.requires_grad = False
        for p in model.lm_head.parameters():
            p.requires_grad = False
def train(attn_implementation="flash_attention_2"):
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments)
)
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
os.makedirs(training_args.output_dir, exist_ok=True)
if "qwen3" in model_args.model_name_or_path.lower() and "a" in Path(model_args.model_name_or_path.rstrip("/")).name.lower():
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
dtype=(torch.bfloat16 if training_args.bf16 else None),
)
data_args.model_type = "qwen3vl"
elif "qwen3" in model_args.model_name_or_path.lower():
model = Qwen3VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
dtype=(torch.bfloat16 if training_args.bf16 else None),
)
data_args.model_type = "qwen3vl"
elif "qwen2.5" in model_args.model_name_or_path.lower():
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
dtype=(torch.bfloat16 if training_args.bf16 else None),
)
data_args.model_type = "qwen2.5vl"
else:
model = Qwen2VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
attn_implementation=attn_implementation,
dtype=(torch.bfloat16 if training_args.bf16 else None),
)
data_args.model_type = "qwen2vl"
    print(f"The initialized model is {model_args.model_name_or_path}; the class is {model.__class__.__name__}")
processor = AutoProcessor.from_pretrained(
model_args.model_name_or_path,
)
if data_args.data_flatten or data_args.data_packing:
replace_qwen2_vl_attention_class()
model.config.use_cache = False
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
model.enable_input_require_grads()
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model, TaskType
print("LoRA enabled")
for p in model.parameters():
p.requires_grad = False
        lora_config = LoraConfig(
            r=training_args.lora_r or 64,
            lora_alpha=training_args.lora_alpha or 128,
            lora_dropout=training_args.lora_dropout or 0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Qwen attention linear layers
            bias="none",
            task_type=TaskType.CAUSAL_LM,
        )
model = get_peft_model(model, lora_config)
else:
set_model(model_args, model)
if torch.distributed.get_rank() == 0:
model.visual.print_trainable_parameters()
model.model.print_trainable_parameters()
data_module = make_supervised_data_module(processor, data_args=data_args)
trainer = Trainer(
model=model, processing_class=tokenizer, args=training_args, **data_module
)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
logging.info("checkpoint found, resume training")
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
model.config.use_cache = True
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
processor.save_pretrained(training_args.output_dir)
if __name__ == "__main__":
train(attn_implementation="flash_attention_2")
from typing import Dict, List, Optional, Sequence, Tuple, Callable
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers import Trainer
from transformers.cache_utils import Cache
from transformers.utils.deprecation import deprecate_kwarg
from transformers.processing_utils import Unpack
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VisionTransformerPretrainedModel,
Qwen2VLModel,
apply_multimodal_rotary_pos_emb,
)
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLModel,
)
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
Qwen3VLVisionModel,
Qwen3VLModel,
apply_rotary_pos_emb,
)
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeVisionModel,
Qwen3VLMoeModel,
)
from transformers.utils import logging
logger = logging.get_logger(__name__)
def flash_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
dropout: float = 0.0,
scaling: Optional[float] = None,
sliding_window: Optional[int] = None,
softcap: Optional[float] = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
if kwargs.get("output_attentions", False) or kwargs.get("head_mask") is not None:
logger.warning_once(
"`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
" Please set your attention to `eager` if you want any of these features."
)
# This is before the transpose
seq_len = query.shape[2]
if any(dim == 0 for dim in query.shape):
raise ValueError(
"Tensor query has shape with a zero dimension.\n"
"FlashAttention does not support inputs with dim=0.\n"
"Please check your input shapes or use SDPA instead."
)
# FA2 uses non-transposed inputs
# batch, head, seq_len, dim
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# batch, seqlen, head, dim
    # In PEFT, the layer norms are usually cast to float32 for training stability, so the input
    # hidden states may have been silently cast to float32 as well. Cast q/k/v back to the
    # expected compute dtype so FlashAttention receives fp16/bf16 inputs. (Casting LayerNorms to
    # fp32 can slow down training and inference; our RMSNorm modules usually handle this correctly.)
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            target_dtype = next(
                layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)
            ).weight.dtype
    if target_dtype is not None:
        query = query.to(target_dtype)
        key = key.to(target_dtype)
        value = value.to(target_dtype)
query = query.squeeze(0)
key = key.squeeze(0)
value = value.squeeze(0)
cu_seqlens = attention_mask
with torch.no_grad():
max_seqlen = max(
[
cu_seqlens[idx + 1] - cu_seqlens[idx]
for idx in range(cu_seqlens.size(0) - 1)
]
).item()
attn_output = flash_attn_varlen_func(
query,
key,
value,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=True,
)
attn_output = attn_output.unsqueeze(0)
return attn_output, None
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def qwen2vl_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)
if past_key_values is not None:
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_output, attn_weights = flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
sliding_window=self.sliding_window,
position_ids=position_ids, # pass positions for FA2
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
@deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
def qwen3vl_forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_output, attn_weights = flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
def return_mask(
config,
input_embeds,
attention_mask,
cache_position,
past_key_values,
position_ids,
**kwargs
):
return attention_mask
def replace_qwen2_vl_attention_class():
import transformers
import transformers.modeling_flash_attention_utils
transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLAttention.forward = (
qwen2vl_forward
)
transformers.models.qwen2_vl.modeling_qwen2_vl.create_causal_mask = (
return_mask
)
transformers.models.qwen2_vl.modeling_qwen2_vl.create_sliding_window_causal_mask = (
return_mask
)
## qwen2_5_vl
transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLAttention.forward = (
qwen2vl_forward
)
transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.create_causal_mask = (
return_mask
)
transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.create_sliding_window_causal_mask = (
return_mask
)
## qwen3vl
transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextAttention.forward = (
qwen3vl_forward
)
transformers.models.qwen3_vl.modeling_qwen3_vl.create_causal_mask = (
return_mask
)
## qwen3vl moe
transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe.Qwen3VLMoeTextAttention.forward = (
qwen3vl_forward
)
transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe.create_causal_mask = (
return_mask
)
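# Usage note: train() calls replace_qwen2_vl_attention_class() only when data_flatten or data_packing
# is enabled. In that mode the "attention_mask" reaching attention is the cu_seqlens tensor produced by
# FlattenedDataCollatorForSupervisedDataset, which return_mask passes through unchanged and
# flash_attention_forward feeds directly to flash_attn_varlen_func.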
def print_trainable_parameters_visual(self) -> None:
"""
Prints the trainable status of all vision components including attention blocks and merger module.
Outputs the indices of trainable/non-trainable blocks and the merger module status.
"""
trainable_blocks = []
non_trainable_blocks = []
# Check trainable status of vision attention blocks
for block_idx, block in enumerate(self.blocks):
is_trainable = all(param.requires_grad for param in block.parameters())
if is_trainable:
trainable_blocks.append(block_idx)
else:
non_trainable_blocks.append(block_idx)
# Check trainable status of merger module
is_merger_trainable = any(param.requires_grad for param in self.merger.parameters())
# Print results
print("Vision Module - Attention Blocks:")
print(
f"Trainable Block Indices: {trainable_blocks if trainable_blocks else 'None'}"
)
print(
f"Non-Trainable Block Indices: {non_trainable_blocks if non_trainable_blocks else 'None'}"
)
print(f"Merger Module Trainable: {is_merger_trainable}")
def print_trainable_parameters(self) -> None:
"""
Prints the trainable status of all LLM components including embeddings, layers, and normalization.
Outputs the indices of trainable/non-trainable layers and other module statuses.
"""
# Check embed_tokens
is_embed_trainable = any(
param.requires_grad for param in self.language_model.embed_tokens.parameters()
)
print(f"LLM Module - Embed Tokens Trainable: {is_embed_trainable}")
# Check each decoder layer
trainable_layers = []
non_trainable_layers = []
for layer_idx, layer in enumerate(self.language_model.layers):
is_trainable = any(param.requires_grad for param in layer.parameters())
if is_trainable:
trainable_layers.append(layer_idx)
else:
non_trainable_layers.append(layer_idx)
# Print layer status
print(
f"LLM Module - Trainable Layer Indices: {trainable_layers if trainable_layers else 'None'}"
)
print(
f"LLM Module - Non-Trainable Layer Indices: {non_trainable_layers if non_trainable_layers else 'None'}"
)
def create_optimizer(self):
opt_model = self.model
if self.optimizer is None:
decay_parameters = self.get_decay_parameter_names(opt_model)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
if self.args.mm_projector_lr is not None and self.args.mm_projector_lr != 0:
projector_parameters = [
name for name, _ in opt_model.named_parameters() if "merger" in name
]
if self.args.vision_tower_lr is not None and self.args.vision_tower_lr != 0:
vision_tower_parameters = [
name for name, _ in opt_model.named_parameters() if "visual" in name
]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and n not in projector_parameters
and n not in vision_tower_parameters
and p.requires_grad
)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and n not in projector_parameters
and n in vision_tower_parameters
and p.requires_grad
)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.vision_tower_lr,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and n not in projector_parameters
and n not in vision_tower_parameters
and p.requires_grad
)
],
"weight_decay": 0.0,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and n not in projector_parameters
and n in vision_tower_parameters
and p.requires_grad
)
],
"weight_decay": 0.0,
"lr": self.args.vision_tower_lr,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and n in projector_parameters
and p.requires_grad
)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.mm_projector_lr,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and n in projector_parameters
and p.requires_grad
)
],
"weight_decay": 0.0,
"lr": self.args.mm_projector_lr,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and n not in projector_parameters
and p.requires_grad
)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and n not in projector_parameters
and p.requires_grad
)
],
"weight_decay": 0.0,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n in decay_parameters
and n in projector_parameters
and p.requires_grad
)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.mm_projector_lr,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (
n not in decay_parameters
and n in projector_parameters
and p.requires_grad
)
],
"weight_decay": 0.0,
"lr": self.args.mm_projector_lr,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args
)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
return self.optimizer
# Apply monkey patches
Trainer.create_optimizer = create_optimizer
Qwen2VisionTransformerPretrainedModel.print_trainable_parameters = (
print_trainable_parameters_visual
)
Qwen2VLModel.print_trainable_parameters = print_trainable_parameters
Qwen2_5_VisionTransformerPretrainedModel.print_trainable_parameters = (
print_trainable_parameters_visual
)
Qwen2_5_VLModel.print_trainable_parameters = print_trainable_parameters
Qwen3VLVisionModel.print_trainable_parameters = (
print_trainable_parameters_visual
)
Qwen3VLModel.print_trainable_parameters = print_trainable_parameters
Qwen3VLMoeVisionModel.print_trainable_parameters = print_trainable_parameters_visual
Qwen3VLMoeModel.print_trainable_parameters = print_trainable_parameters
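# Minimal usage sketch (illustrative only; the attribute paths and model ID below are
# assumptions and may vary across transformers versions):
#
#   from transformers import Qwen2_5_VLForConditionalGeneration
#
#   replace_qwen2_vl_attention_class()  # install patched attention + no-op mask builders
#   model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#       "Qwen/Qwen2.5-VL-3B-Instruct", attn_implementation="flash_attention_2"
#   )
#   model.model.visual.print_trainable_parameters()  # vision blocks / merger status
#   model.model.print_trainable_parameters()         # LLM embeddings / layers status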
\ No newline at end of file
#!/bin/bash
# Distributed training configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NNODES=${WORLD_SIZE:-1}
NPROC_PER_NODE=${NPROC_PER_NODE:-8}  # GPUs per node; adjust to your hardware
# DeepSpeed configuration
deepspeed=./scripts/zero3.json
# Model configuration
llm=Qwen/Qwen2.5-VL-3B-Instruct # Using HuggingFace model ID
# Training hyperparameters
lr=2e-7
batch_size=4
grad_accum_steps=4
# Training entry point
entry_file=qwenvl/train/train_qwen.py
# Dataset configuration (replace with public dataset names)
datasets=public_dataset1,public_dataset2
# Output configuration
run_name="qwen2vl-baseline"
output_dir=./output
# Training arguments
args="
--deepspeed ${deepspeed} \
--model_name_or_path "${llm}" \
--dataset_use ${datasets} \
--data_flatten True \
--tune_mm_vision False \
--tune_mm_mlp True \
--tune_mm_llm True \
--bf16 \
--output_dir ${output_dir} \
--num_train_epochs 0.5 \
--per_device_train_batch_size ${batch_size} \
--per_device_eval_batch_size $((batch_size*2)) \
--gradient_accumulation_steps ${grad_accum_steps} \
--max_pixels 50176 \
--min_pixels 784 \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate ${lr} \
--weight_decay 0 \
--warmup_ratio 0.03 \
--max_grad_norm 1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ${run_name} \
--report_to wandb"
# Launch training
torchrun --nproc_per_node=${NPROC_PER_NODE} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \
${entry_file} ${args}
\ No newline at end of file
#!/bin/bash
# Distributed training configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NNODES=${WORLD_SIZE:-1}
NPROC_PER_NODE=${NPROC_PER_NODE:-8}  # GPUs per node; adjust to your hardware
# DeepSpeed configuration
# The MoE model only supports ZeRO-2
# This script is sized to run on 32x 80GB GPUs
deepspeed=./scripts/zero2.json
# Model configuration
llm=Qwen/Qwen3-VL-30B-A3B-Instruct # Using HuggingFace model ID
# Training hyperparameters
lr=1e-5
batch_size=1
grad_accum_steps=4
# Training entry point
entry_file=qwenvl/train/train_qwen.py
# Dataset configuration (replace with public dataset names)
datasets=public_dataset1,public_dataset2
# Output configuration
run_name="qwen3vl-moe"
output_dir=./output
# Training arguments
args="
--deepspeed ${deepspeed} \
--model_name_or_path "${llm}" \
--dataset_use ${datasets} \
--data_flatten True \
--tune_mm_vision False \
--tune_mm_mlp True \
--tune_mm_llm True \
--bf16 \
--output_dir ${output_dir} \
--num_train_epochs 0.5 \
--per_device_train_batch_size ${batch_size} \
--per_device_eval_batch_size $((batch_size*2)) \
--gradient_accumulation_steps ${grad_accum_steps} \
--max_pixels 50176 \
--min_pixels 784 \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate ${lr} \
--weight_decay 0 \
--warmup_ratio 0.03 \
--max_grad_norm 1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ${run_name} \
--report_to wandb"
# Launch training
torchrun --nproc_per_node=${NPROC_PER_NODE} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \
${entry_file} ${args}
\ No newline at end of file
#!/bin/bash
# Distributed training configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NNODES=${WORLD_SIZE:-1}
NPROC_PER_NODE=${NPROC_PER_NODE:-8}  # GPUs per node; adjust to your hardware
# DeepSpeed configuration
# The MoE model only supports ZeRO-2
# This script is sized to run on 32x 80GB GPUs
deepspeed=./scripts/zero2.json
# Model configuration
llm=Qwen/Qwen3-VL-30B-A3B-Instruct # Using HuggingFace model ID
# Training hyperparameters
lr=1e-5
batch_size=1
grad_accum_steps=4
# Training entry point
entry_file=qwenvl/train/train_qwen.py
# Dataset configuration (replace with public dataset names)
datasets=public_dataset1,public_dataset2
# Output configuration
run_name="qwen3vl-moe-lora"
output_dir=./output
# Training arguments
args="
--deepspeed ${deepspeed} \
--model_name_or_path "${llm}" \
--dataset_use ${datasets} \
--data_flatten True \
--tune_mm_vision False \
--tune_mm_mlp True \
--tune_mm_llm True \
--bf16 \
--lora_enable True \
--output_dir ${output_dir} \
--num_train_epochs 0.5 \
--per_device_train_batch_size ${batch_size} \
    --per_device_eval_batch_size $((batch_size*2)) \
--gradient_accumulation_steps ${grad_accum_steps} \
--max_pixels 50176 \
--min_pixels 784 \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate ${lr} \
--weight_decay 0 \
--warmup_ratio 0.03 \
--max_grad_norm 1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ${run_name} \
--report_to wandb"
# Launch training
torchrun --nproc_per_node=${NPROC_PER_NODE} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \
${entry_file} ${args}
\ No newline at end of file
#!/bin/bash
# Distributed training configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NNODES=${WORLD_SIZE:-1}
NPROC_PER_NODE=${NPROC_PER_NODE:-8}  # GPUs per node; adjust to your hardware
# DeepSpeed configuration
deepspeed=./scripts/zero3.json
# Model configuration
llm=Qwen/Qwen2.5-VL-32B-Instruct # Using HuggingFace model ID
# Training hyperparameters
lr=2e-7
batch_size=2
grad_accum_steps=8
# Training entry point
entry_file=qwenvl/train/train_qwen.py
# Dataset configuration (replace with public dataset names)
datasets=public_dataset1,public_dataset2
# Output configuration
run_name="qwen2vl-baseline"
output_dir=./output
# Training arguments
args="
--deepspeed ${deepspeed} \
--model_name_or_path "${llm}" \
--dataset_use ${datasets} \
--data_flatten True \
--tune_mm_vision False \
--tune_mm_mlp True \
--tune_mm_llm True \
--bf16 \
--output_dir ${output_dir} \
--num_train_epochs 0.5 \
--per_device_train_batch_size ${batch_size} \
--per_device_eval_batch_size $((batch_size*2)) \
--gradient_accumulation_steps ${grad_accum_steps} \
--max_pixels 50176 \
--min_pixels 784 \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate ${lr} \
--weight_decay 0 \
--warmup_ratio 0.03 \
--max_grad_norm 1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ${run_name} \
--report_to wandb"
# Launch training
torchrun --nproc_per_node=${NPROC_PER_NODE} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \
${entry_file} ${args}
\ No newline at end of file
#!/bin/bash
# Distributed training configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NNODES=${WORLD_SIZE:-1}
NPROC_PER_NODE=${NPROC_PER_NODE:-8}  # GPUs per node; adjust to your hardware
# DeepSpeed configuration
deepspeed=./scripts/zero3.json
# Model configuration
llm=Qwen/Qwen2.5-VL-7B-Instruct # Using HuggingFace model ID
# Training hyperparameters
lr=2e-7
batch_size=4
grad_accum_steps=4
# Training entry point
entry_file=qwenvl/train/train_qwen.py
# Dataset configuration (replace with public dataset names)
datasets=public_dataset1,public_dataset2
# Output configuration
run_name="qwen2vl-baseline"
output_dir=./output
# Training arguments
args="
--deepspeed ${deepspeed} \
--model_name_or_path "${llm}" \
--dataset_use ${datasets} \
--data_flatten True \
--tune_mm_vision False \
--tune_mm_mlp True \
--tune_mm_llm True \
--bf16 \
--output_dir ${output_dir} \
--num_train_epochs 0.5 \
--per_device_train_batch_size ${batch_size} \
--per_device_eval_batch_size $((batch_size*2)) \
--gradient_accumulation_steps ${grad_accum_steps} \
--max_pixels 50176 \
--min_pixels 784 \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate ${lr} \
--weight_decay 0 \
--warmup_ratio 0.03 \
--max_grad_norm 1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ${run_name} \
--report_to wandb"
# Launch training
torchrun --nproc_per_node=${NPROC_PER_NODE} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \
${entry_file} ${args}
\ No newline at end of file
#!/bin/bash
# Distributed training configuration
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
MASTER_PORT=${MASTER_PORT:-$(shuf -i 20001-29999 -n 1)}
NNODES=${WORLD_SIZE:-1}
NPROC_PER_NODE=${NPROC_PER_NODE:-8}  # GPUs per node; adjust to your hardware
# DeepSpeed configuration
deepspeed=./scripts/zero3.json
# Model configuration
llm=Qwen/Qwen3-VL-4B-Instruct # Using HuggingFace model ID
# Training hyperparameters
lr=1e-5
batch_size=4
grad_accum_steps=4
# Training entry point
entry_file=qwenvl/train/train_qwen.py
# Dataset configuration (replace with public dataset names)
datasets=public_dataset1,public_dataset2
# Output configuration
run_name="qwen3vl"
output_dir=./output
# Training arguments
args="
--deepspeed ${deepspeed} \
--model_name_or_path "${llm}" \
--dataset_use ${datasets} \
--data_flatten True \
--tune_mm_vision False \
--tune_mm_mlp True \
--tune_mm_llm True \
--bf16 \
--output_dir ${output_dir} \
--num_train_epochs 0.5 \
--per_device_train_batch_size ${batch_size} \
--per_device_eval_batch_size $((batch_size*2)) \
--gradient_accumulation_steps ${grad_accum_steps} \
--max_pixels 50176 \
--min_pixels 784 \
--eval_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 1 \
--learning_rate ${lr} \
--weight_decay 0 \
--warmup_ratio 0.03 \
--max_grad_norm 1 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--model_max_length 8192 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--run_name ${run_name} \
--report_to wandb"
# Launch training
torchrun --nproc_per_node=${NPROC_PER_NODE} \
--master_addr=${MASTER_ADDR} \
--master_port=${MASTER_PORT} \
${entry_file} ${args}
\ No newline at end of file
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto"
}
}
\ No newline at end of file
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"train_micro_batch_size_per_gpu": "auto",
"train_batch_size": "auto",
"gradient_accumulation_steps": "auto",
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
}
}
\ No newline at end of file
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"steps_per_print": 1e5,
"wall_clock_breakdown": false
}
\ No newline at end of file