Commit 1b9205c9 authored by yangzhong

v1.0
import os
import copy
from dataclasses import dataclass
import json
from glob import glob
import random
from typing import Dict, Optional, Sequence, List, Iterator
from operator import itemgetter
from tqdm import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler, Sampler
import transformers
from PIL import Image
import conversation as conversation_lib
from data_utils import DataInfo
from open_flamingo.train.any_res_data_utils import process_anyres_image
from data_configs.data_paths import IMAGE_FOLDER_DICT_GCP
LOGDIR = "."
# Model Constants
IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
def get_image_fullpath(image_file):
image_file_fp = None
for k, v in IMAGE_FOLDER_DICT_GCP.items():
if k in image_file:
image_file_fp = image_file.replace(k, v)
break
if image_file_fp is None:
print(f"File not found: {image_file}")
exit(0)
return image_file_fp
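# A minimal illustration (hypothetical filename, assuming the SoM-LLaVA entry from the
# example data_paths config in this commit): keys of IMAGE_FOLDER_DICT_GCP are path
# prefixes that get rewritten to the locally mounted dataset folders.
if __name__ == "__main__":
    _fp = get_image_fullpath("coco/som_train2017/000000000139.jpg")
    print(_fp)  # e.g. /blip-3_pytorch/dataset/SoM-LLaVA/som_train2017/000000000139.jpg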
def preprocess_phi_3(
sources,
conv_template,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False
) -> Dict:
conv = conv_template.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["from"]]
assert role == conv.roles[j % 2], f"{i}"
conv.append_message(role, sentence["value"])
conversations.append(conv.get_prompt())
# Tokenize conversations.
# Truncate to 2048 to save memory.
if tokenizer.model_max_length > 2048:
max_len = 2048
else:
max_len = tokenizer.model_max_length
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=max_len,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.PHI_3
# Mask targets
sep = conv.roles[1] + "\n"
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2+'\n')
rounds_len = len(rounds)
cur_len = 0 # No <bos> token.
for i, rou in enumerate(rounds):
if rou == "":
break
rou += conv.sep2+'\n'
if sep in rou:
# assistant round
round_ids = tokenizer(rou,
max_length=max_len,
truncation=True).input_ids
role_prefix_ids = tokenizer(sep).input_ids
len_prefix = len(role_prefix_ids)
round_ids = round_ids[len_prefix:]
round_len = len(round_ids)
elif conv.roles[0] in rou:
# user round
rou += sep
if has_image:
round_ids = tokenizer(rou,
max_length=max_len,
truncation=True).input_ids
if i > 0:
round_ids = round_ids[2:] # Skip the bos tokens
round_len = len(round_ids)
instruction_len = round_len # All are instructions.
else:
round_ids = tokenizer(rou).input_ids
if i > 0:
round_ids = round_ids[2:] # Skip the bos tokens
round_len = len(round_ids)
instruction_len = round_len
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
else:
# system round
round_ids = tokenizer(rou,
max_length=max_len,
truncation=True).input_ids
round_len = len(round_ids)
instruction_len = round_len # All are instructions.
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < max_len: # The input_ids are truncated to this max length.
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_phi_3_new(
sources,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
role_mapping = {"human": "user", "gpt": "assistant"}
roles = ("<|user|>", "<|assistant|>")
sep="<s>"
sep2="<|end|>"
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
        # TODO: add a system prompt if there isn't one in source.
# Update key names
for i, rnd in enumerate(source):
if "from" in rnd:
if rnd["from"] in ["human", "gpt"]:
rnd["role"] = role_mapping[rnd.pop("from")]
else:
rnd["role"] = rnd.pop("from")
if "value" in rnd:
rnd["content"] = rnd.pop("value")
# Apply chat template
tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'system' %}{{ '<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"
chat_conv = tokenizer.apply_chat_template(source, tokenize=False)
chat_conv = chat_conv.replace(tokenizer.bos_token,'')
conversations.append(chat_conv)
# Tokenize conversations
if tokenizer.model_max_length > 2048:
max_len = 2048
else:
max_len = tokenizer.model_max_length
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=max_len,
truncation=True,
).input_ids
targets = input_ids.clone()
# assert conv.sep_style == conversation_lib.SeparatorStyle.PHI_3
# Mask targets
sep = roles[1] + "\n"
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(sep2+'\n')
cur_len = 0 # No <bos> token.
for i, rou in enumerate(rounds):
if rou == "":
break
rou += sep2+'\n'
if sep in rou:
# assistant round
round_ids = tokenizer(rou,
max_length=max_len,
truncation=True).input_ids
role_prefix_ids = tokenizer(sep).input_ids
len_prefix = len(role_prefix_ids)
round_ids = round_ids[len_prefix:]
round_len = len(round_ids)
elif roles[0] in rou:
# user round
rou += sep
round_ids = tokenizer(rou,
max_length=max_len,
truncation=True).input_ids
if i > 0:
round_ids = round_ids[1:] # Skip the bos tokens
round_len = len(round_ids)
instruction_len = round_len # All are instructions.
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
else:
# system round
round_ids = tokenizer(rou,
max_length=max_len,
truncation=True).input_ids
round_len = len(round_ids)
instruction_len = round_len # All are instructions.
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < max_len: # The input_ids are truncated to this max length.
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
conv_template_name: Optional[str] = None,
) -> Dict:
"""
Given a list of sources, each is a conversation list. This transform:
1. Add signal '### ' at the beginning each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
if conv_template_name is not None and conv_template_name in conversation_lib.conv_templates.keys():
        # Use the specified preprocessing template.
conv_template = conversation_lib.conv_templates[conv_template_name]
else:
conv_template = conversation_lib.default_conversation
if conv_template.version.startswith("phi_3"):
return preprocess_phi_3_new(sources, tokenizer)
else:
raise NotImplementedError
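# Illustrative sketch (not part of the original pipeline): the `sources` argument of
# `preprocess` is a list of conversation lists; each turn is a dict with "from" in
# {"system", "human", "gpt"} and a "value" string, where user turns may carry
# DEFAULT_IMAGE_TOKEN placeholders.
if __name__ == "__main__":
    _example_sources = [[
        {"from": "system", "value": "A chat between a curious user and an artificial intelligence assistant."},
        {"from": "human", "value": f"{DEFAULT_IMAGE_TOKEN}\nDescribe the image."},
        {"from": "gpt", "value": "A dog is playing in the park."},
    ]]
    # preprocess(_example_sources, tokenizer, conv_template_name="phi_3") would return
    # dict(input_ids=..., labels=...) with system/user tokens set to IGNORE_INDEX.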
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
image_processor,
data_args,
# data_args: DataArguments
):
super(LazySupervisedDataset, self).__init__()
if isinstance(data_path, str) and os.path.isfile(data_path):
# Load the default 650k data mix.
list_data_dict = json.load(open(data_path, "r"))
elif isinstance(data_path, str) and os.path.isdir(data_path):
# Load a custom mixture of data with a list of json files.
json_lists = glob(os.path.join(data_path, '*.json'))
list_data_dict = []
for json_file in json_lists:
list_data_dict.extend(json.load(open(json_file, "r")))
elif isinstance(data_path, Dict):
            # data_path: a dict (e.g., parsed from a yaml config) mapping json annotation files to sample counts.
list_data_dict = []
for json_file, n_sample in data_path.items():
d_json = json.load(open(json_file, "r"))
# print(f"Loaded {json_file} with {len(d_json)} items, requesting {n_sample} samples.")
if n_sample > len(d_json):
# print(f"Warning: Requested {n_sample} samples, but only {len(d_json)} available. Using random.choices.")
list_data_dict.extend(random.Random(42).choices(d_json, k=n_sample))
else:
list_data_dict.extend(random.Random(42).sample(d_json, k=n_sample))
else:
raise ValueError(f"Unknown data_path type: {data_path}")
# rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.image_processor = image_processor
self.conv_template_name = data_args.conv_template_name
self.list_data_dict = list_data_dict
self.data_args = data_args
self.anyres_grids = []
base_img_size = self.image_processor.transforms[0].size[0]
for (m,n) in data_args.anyres_grids:
self.anyres_grids.append([base_img_size*m, base_img_size*n])
def __len__(self):
return len(self.list_data_dict)
@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if 'image' in sample else 0
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
return length_list
@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
cur_len = cur_len if 'image' in sample else -cur_len
length_list.append(cur_len)
return length_list
def _process_single_image(self, image_file) -> Dict[str, torch.Tensor]:
image_file_fullpath = get_image_fullpath(image_file)
success = True
try:
image = Image.open(image_file_fullpath).convert('RGB')
        except Exception:
print(f"error opening the file: {image_file_fullpath}")
success = False
return success, None, None
processor = self.image_processor
img_size = image.size
if self.data_args.image_aspect_ratio == 'pad':
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
# FIXME: Hardcoded workaround to work with torchvision.Compose()
image = expand2square(image, tuple(int(x*255) for x in processor.transforms[-1].mean))
image = processor(image) # FIXME: whether to take the 0-th item.
elif self.data_args.image_aspect_ratio == "anyres":
# Return image shape: [N_patch, C, H, W]
image = process_anyres_image(image, processor, self.anyres_grids)
else:
image = processor(image)
return success, image, img_size
def _check_img_token_nums(self, source):
keep_sample = True
if 'image' not in source:
# Make sure no <image> token in text-only samples.
for conv in source["conversations"]:
n_img_token = conv["value"].count(DEFAULT_IMAGE_TOKEN)
if n_img_token > 0:
keep_sample = False
break
return keep_sample, source
n_image = len(source['image']) if isinstance(source['image'], list) else 1
if n_image > 1:
            # FIXME: the checker below doesn't work for mantis. Currently it only checks single-image data.
return keep_sample, source
for conv in source["conversations"]:
if conv["from"] == "human":
n_img_token = conv["value"].count(DEFAULT_IMAGE_TOKEN)
if not n_img_token == n_image:
# print(source)
conv["value"] = conv["value"].replace(DEFAULT_IMAGE_TOKEN, '').strip()
conv["value"] = f"{DEFAULT_IMAGE_TOKEN}\n" * n_image + conv["value"]
break
return keep_sample, source
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
sources = self.list_data_dict[i]
keep_sample, sources = self._check_img_token_nums(sources)
if not keep_sample:
return self.__getitem__(i+1)
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
# Add the system prompt.
system_round = {
"from": "system",
"value": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
}
if sources[0]["conversations"][0]["from"] != "system":
sources[0]["conversations"] = [system_round] + sources[0]["conversations"]
if 'image' in sources[0]:
has_image = True
image_file = sources[0]['image']
if isinstance(image_file, list):
# FIXME: Skipping samples with more than 4 images to avoid OOM issue.
if len(image_file) > 4:
return self.__getitem__(i+1)
image = []
img_size = []
for single_image in image_file:
success, image_i, img_size_i = self._process_single_image(single_image)
if not success:
# Skip the entire sample if one of the images can't be opened.
return self.__getitem__(i+1)
image.append(image_i)
img_size.append(img_size_i)
elif isinstance(image_file, str):
success, image, img_size = self._process_single_image(image_file)
if not success:
# Skip the entire sample if one of the images can't be opened.
return self.__getitem__(i+1)
else:
raise NotImplementedError(f"Unknown image_file type: {image_file}")
sources = copy.deepcopy([e["conversations"] for e in sources])
else:
has_image = False
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = preprocess(
sources,
self.tokenizer,
conv_template_name=self.conv_template_name)
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
# image exist in the data
if has_image:
if isinstance(image, list):
# Multi-image, each image can be of 4-dim (anyres) or 3-dim (base res)
data_dict['image'] = image
if image[0].ndim == 3:
# Stack base res image groups along the T-dim.
image = torch.stack(image, dim=0)
data_dict['image'] = image.unsqueeze(1) # [T, 1, C, H, W]
elif image.ndim == 4: # Any-res image patches of a single image - use the F dim for N-patches.
data_dict['image'] = image[None, :]
else: # single image, single frame
data_dict['image'] = image[None, None, :] # Expand dims with [T_img, F] to be compatible with flamingo-like vision encoding.
data_dict['image_size'] = img_size
else:
# image does not exist in the data, but the model is multimodal
crop_size = self.image_processor.transforms[0].size # FIXME: Hardcoded workaround to work with torchvision.Compose()
data_dict['image'] = torch.zeros(1, 1, 3, crop_size[0], crop_size[1]) # Expand dims with [T_img, F] to be compatible with flamingo-like vision encoding.
data_dict['image_size'] = crop_size
return data_dict
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
"""
Stack a list of tensors with padding on one side
Args:
list_of_tensors (list[torch.Tensor]): List of tensors to stack
padding_value (int, optional): Value to pad with. Defaults to 0.
padding_side (str, optional): Side to pad on. Defaults to "right".
Returns:
torch.Tensor: Stacked tensors
"""
max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
padded_tensors = []
for tensor in list_of_tensors:
num_tokens = tensor.size(0)
padding = torch.full(
(max_tokens - num_tokens,) + tuple(tensor.shape[1:]),
padding_value,
dtype=tensor.dtype,
device=tensor.device,
)
padded_tensor = (
torch.cat((tensor, padding), dim=0)
if padding_side == "right"
else torch.cat((padding, tensor), dim=0)
)
padded_tensors.append(padded_tensor)
return torch.stack(padded_tensors)
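# Minimal usage sketch: two token tensors of different lengths are stacked into a single
# batch tensor, with the shorter one right-padded by `padding_value`.
if __name__ == "__main__":
    _a = torch.ones(3, 4)
    _b = torch.ones(5, 4)
    _stacked = stack_with_padding([_a, _b], padding_value=0, padding_side="right")
    assert _stacked.shape == (2, 5, 4)
    assert _stacked[0, 3:].abs().sum() == 0  # padded region of the shorter tensor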
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
image_aspect_ratio: str
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances]
for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=IGNORE_INDEX)
input_ids = input_ids[:, :self.tokenizer.model_max_length]
labels = labels[:, :self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
if 'image' in instances[0]:
images = [instance['image'] for instance in instances]
image_size = [instance['image_size'] for instance in instances]
batch['image_size'] = image_size
if any(isinstance(x, list) for x in images):
images_list = []
for x in images:
if isinstance(x, torch.Tensor):
images_list.append([x])
elif isinstance(x, list):
images_list.append(x)
else:
raise NotImplementedError(f"Unknown data type: {x}")
image_size_list = []
for x in image_size:
if not isinstance(x, list):
image_size_list.append([x])
else:
image_size_list.append(x)
batch['images'] = images_list
batch['image_size'] = image_size_list
elif images[0].shape[0] == 1 and all(x is not None and x.shape == images[0].shape for x in images):
# stacking images when not using anyres.
batch['images'] = torch.stack(images)
elif images[0].ndim == 5 and self.image_aspect_ratio != 'anyres':
# Stacking batch of multi-image base-res image groups with padding.
batch['images'] = stack_with_padding(images)
else:
batch['images'] = images
return batch
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `chunks` chunks of roughly equal lengths.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
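# Minimal usage sketch: indices are assigned greedily to the chunk with the smallest
# running total length, so the per-chunk summed lengths stay roughly balanced.
if __name__ == "__main__":
    _lengths = [5, 9, 2, 7, 3, 8]
    _chunks = split_to_even_chunks(list(range(len(_lengths))), _lengths, num_chunks=2)
    assert sorted(i for c in _chunks for i in c) == list(range(len(_lengths)))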
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
if generator is not None:
torch.manual_seed(42)
megabatch_indices = torch.randperm(len(megabatches), generator=generator.manual_seed(42))
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
group_by_modality: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
self.lengths = lengths
self.generator = generator
self.group_by_modality = group_by_modality
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.group_by_modality:
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
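# Minimal usage sketch (hypothetical token-length estimates): the sampler permutes all
# indices but keeps samples of similar length within the same world_size * batch_size
# megabatch, which reduces padding waste across GPUs.
if __name__ == "__main__":
    _lengths = [120, 30, 45, 400, 12, 256, 77, 90]
    _sampler = LengthGroupedSampler(batch_size=2, world_size=2, lengths=_lengths,
                                    generator=torch.Generator().manual_seed(0))
    _order = list(_sampler)
    assert sorted(_order) == list(range(len(_lengths)))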
class DatasetFromSampler(Dataset):
"""Dataset to create indexes from `Sampler`.
Args:
sampler: PyTorch sampler
"""
def __init__(self, sampler: Sampler):
"""Initialisation for DatasetFromSampler."""
self.sampler = sampler
self.sampler_list = None
def __getitem__(self, index: int):
"""Gets element of the dataset.
Args:
index: index of the element in the dataset
Returns:
Single element by index
"""
if self.sampler_list is None:
self.sampler_list = list(self.sampler)
return self.sampler_list[index]
def __len__(self) -> int:
"""
Returns:
int: length of the dataset
"""
return len(self.sampler)
class DistributedSamplerWrapper(DistributedSampler):
"""
https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py
Wrapper over `Sampler` for distributed training.
Allows you to use any sampler in distributed mode.
It is especially useful in conjunction with
`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSamplerWrapper instance as a DataLoader
sampler, and load a subset of subsampled data of the original dataset
that is exclusive to it.
.. note::
Sampler is assumed to be of constant size.
"""
def __init__(
self,
sampler,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
):
"""
Args:
sampler: Sampler used for subsampling
num_replicas (int, optional): Number of processes participating in
distributed training
rank (int, optional): Rank of the current process
within ``num_replicas``
shuffle (bool, optional): If true (default),
sampler will shuffle the indices
"""
super(DistributedSamplerWrapper, self).__init__(
DatasetFromSampler(sampler),
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
)
self.sampler = sampler
def __iter__(self) -> Iterator[int]:
"""Iterate over sampler.
Returns:
python iterator
"""
self.dataset = DatasetFromSampler(self.sampler)
indexes_of_indexes = super().__iter__()
subsampler_indexes = self.dataset
return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
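# Minimal usage sketch (no process group needed when num_replicas/rank are passed
# explicitly): wrapping a plain SequentialSampler over 8 items and taking rank 0 of 2
# replicas yields every other index.
if __name__ == "__main__":
    from torch.utils.data import SequentialSampler
    _base = SequentialSampler(range(8))
    _wrapped = DistributedSamplerWrapper(_base, num_replicas=2, rank=0, shuffle=False)
    assert list(_wrapped) == [0, 2, 4, 6]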
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
image_processor,
data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
image_processor=image_processor,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer,
image_aspect_ratio=data_args.image_aspect_ratio)
if data_args.data_sampler_group_by_length:
# Use length grouped sampler for more balanced GPU usages.
lengths = train_dataset.modality_lengths
sampler_inner = LengthGroupedSampler(
data_args.batch_size,
world_size=data_args.world_size * data_args.gradient_accumulation_steps,
lengths=lengths,
group_by_modality=True,
generator=torch.Generator().manual_seed(42),
)
sampler = DistributedSamplerWrapper(
sampler=sampler_inner,
num_replicas=data_args.world_size,
rank=data_args.rank,
shuffle=False
)
else:
sampler = DistributedSampler(
train_dataset,
shuffle=True,
num_replicas=dist.get_world_size(),
rank=dist.get_rank(),
)
# sampler = None
data_loader = DataLoader(
train_dataset,
batch_size=data_args.batch_size,
num_workers=data_args.workers,
pin_memory=True,
sampler=sampler,
shuffle=sampler is None,
collate_fn=data_collator,
)
return DataInfo(
name='instruction-finetune-mix',
dataloader=data_loader,
batch_size=data_args.batch_size,
loss_multiplier=1.0,
shared_epoch=None,
sampler=sampler,
), len(train_dataset)
from argparse import Namespace
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import open_clip
from sft_data_utils import make_supervised_data_module
IGNORE_INDEX = -100
if __name__=='__main__':
# Constant for unit test.
tokenizer_path = 'lmsys/vicuna-7b-v1.5'
clip_vision_encoder_path = 'ViT-H-14-378-quickgelu'
clip_vision_encoder_pretrained = 'dfn5b'
cache_dir='/export/share/manlis/models'
# load tokenizer and ensure there is a pad token
text_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
local_files_only=False,
trust_remote_code=True,
cache_dir=cache_dir,
use_fast=False,
)
if text_tokenizer.pad_token is None or text_tokenizer.pad_token == text_tokenizer.eos_token:
# add a pad token if it doesn't exist
text_tokenizer.add_special_tokens({"pad_token": "<pad>"})
# add special tokens to the tokenizer and language models
special_tokens = {
"media_token": "<image>",
}
text_tokenizer.add_special_tokens(
{"additional_special_tokens": list(special_tokens.values())}
)
# load vision encoder
_, _, image_processor = open_clip.create_model_and_transforms(
clip_vision_encoder_path,
pretrained=clip_vision_encoder_pretrained,
cache_dir=cache_dir,
force_image_size=378,
)
# Create dataset.
args = Namespace(
data_sampler_group_by_length=False,
data_path='/export/share/manlis/data/lavis/llava_instruct_665k_sharegpt4v/annotations/sharegpt4v_mix665k_cap23k_coco-ap9k_lcs3k_sam9k_div2k.json',
batch_size=8,
world_size=8,
gradient_accumulation_steps=1,
rank=0,
workers=4,
image_aspect_ratio='pad',
is_multimodal=True,
        mm_use_im_start_end=False,
        # The dataset class also reads these fields; values here follow the phi-3 / any-res setup used elsewhere in this commit.
        conv_template_name='phi_3',
        anyres_grids=[(1, 2), (2, 1), (2, 2), (3, 1), (1, 3)],
    )
train_dataset, total_num_samples = make_supervised_data_module(tokenizer=text_tokenizer,
image_processor=image_processor,
data_args=args)
# Iter through all data samples.
print(len(train_dataset.dataloader))
for i, sample in enumerate(train_dataset.dataloader):
if (sample['labels'] == IGNORE_INDEX).all():
print(f"sample {i} token mismatch")
pass
""" Main training script """
import argparse
import os
import torch
import wandb
import functools
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from open_flamingo import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES
from open_flamingo.train.data import get_data, SUPPORTED_DATASETS
from open_flamingo.train.distributed import (
init_distributed_device,
world_info_from_env,
get_fsdp_config,
get_fsdp_checkpoint_config,
)
from open_flamingo.train.train_utils import (
train_one_epoch,
random_seed,
find_most_recent_checkpoint,
load_checkpoint,
save_checkpoint,
)
from open_flamingo.train.losses import (
SUPPORTED_LOSSES,
get_loss_fn,
)
from transformers import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
def main():
parser = argparse.ArgumentParser()
# model configuration args
parser.add_argument(
"--model_family", default="flamingo", type=str, choices=SUPPORTED_MODEL_FAMILIES
)
parser.add_argument("--vision_encoder_path", default="ViT-SO400M-14-SigLIP-384", type=str)
parser.add_argument("--vision_encoder_pretrained", default="webli", type=str)
parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
parser.add_argument(
"--tokenizer_path",
default="facebook/opt-30b",
type=str,
help="path to tokenizer",
)
parser.add_argument(
"--cross_attn_every_n_layers",
type=int,
default=1,
help="how often to add a cross-attention layer after each transformer layer",
)
# training args
parser.add_argument(
"--loss", type=str, choices=SUPPORTED_LOSSES, default="next_token_prediction"
)
parser.add_argument(
"--run_name",
type=str,
default="openflamingo3B",
help="used to name saving directory and wandb run",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default.",
default=None,
)
parser.add_argument(
"--delete_previous_checkpoint",
action="store_true",
help="delete previous checkpoint when saving new checkpoint",
)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--learning_rate", default=1e-4, type=float)
parser.add_argument(
"--lr_scheduler",
default="constant",
type=str,
help="constant, linear, or cosine",
)
parser.add_argument("--warmup_steps", default=5000, type=int)
parser.add_argument("--weight_decay", default=0.1, type=float)
parser.add_argument(
"--precision",
choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
default="fp32",
help="Floating point precision.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="whether to train with gradient/activation checkpointing",
)
parser.add_argument(
"--num_epochs",
type=int,
default=1,
help="we define an 'epoch' as a fixed number of examples specified by train_num_samples, not a pass through the entire dataset",
)
parser.add_argument("--offline", action="store_true")
parser.add_argument(
"--logging_steps", type=int, default=100, help="log loss every n steps"
)
# data args
for dataset_name in SUPPORTED_DATASETS:
parser.add_argument(f"--batch_size_{dataset_name}", type=int, default=128)
parser.add_argument(
f"--loss_multiplier_{dataset_name}", type=float, default=1.0
)
parser.add_argument(
f"--train_num_samples_{dataset_name}",
type=int,
default=10000,
help="Number of samples in an 'epoch' for this dataset. Note that train_num_samples/batch_size must be the same for all datasets.",
)
parser.add_argument(
f"--{dataset_name}_shards",
type=str,
default=None,
help="Should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar. If None, we will not train on this dataset.",
)
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--dataset_resampled", action="store_true")
parser.add_argument(
"--mmc4_textsim_threshold",
default=0.24,
type=float,
help="threshold for filtering images in mmc4 based on image-text similarity",
)
parser.add_argument(
"--mmc4_max_num_images",
default=6,
type=int,
help="max number of images per sequence in mmc4 / chatgpt",
)
parser.add_argument(
"--mmc4_min_num_images",
default=1,
type=int,
help="min number of images per sequence in mmc4 / chatgpt",
)
# distributed training args
parser.add_argument(
"--dist-url",
default="env://",
type=str,
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--horovod",
default=False,
action="store_true",
help="Use horovod for distributed training.",
)
parser.add_argument(
"--no-set-device-rank",
default=False,
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)
# fsdp args
parser.add_argument(
"--fsdp",
default=False,
action="store_true",
help="Use FullyShardedDataParallel for distributed training. Not supported for some models, e.g. OPT.",
)
parser.add_argument(
"--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid", "shard_grad_op", "hybrid_shard_grad_op", "no_shard"]
)
# wandb args
parser.add_argument("--report_to_wandb", default=False, action="store_true")
parser.add_argument(
"--wandb_project",
type=str,
)
parser.add_argument(
"--wandb_entity",
type=str,
)
parser.add_argument(
"--save_checkpoints_to_wandb",
default=False,
action="store_true",
help="save checkpoints to wandb",
)
parser.add_argument(
'--local-rank',
default=0,
type=int,
help='Local rank for distributed training'
)
parser.add_argument(
'--use_flash_attention_2',
default=False, action='store_true',
help='Use Flash Attention 2.0 for language model.'
)
parser.add_argument(
'--unfreeze_vision_encoder',
default=False, action='store_true',
help='Unfreeze vision encoder during training.'
)
parser.add_argument(
'--vision_encoder_precision',
default='fp32',
choices=["bf16", "fp32"],
help='Precision of the vision encoder during training.'
)
parser.add_argument(
'--cpu_offload_gradients',
default=False, action='store_true',
help='This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.'
)
args = parser.parse_args()
# Parse which datasets to train on and which to exclude
datasets_to_train_on = []
for dataset_name in SUPPORTED_DATASETS:
if getattr(args, f"{dataset_name}_shards") is None:
print(f"Excluding {dataset_name} from training")
setattr(args, f"train_num_samples_{dataset_name}", 0)
setattr(args, f"batch_size_{dataset_name}", 0)
else:
datasets_to_train_on.append(dataset_name)
shards_path = getattr(args, f"{dataset_name}_shards")
if shards_path.startswith("s3"):
setattr(
args,
f"{dataset_name}_shards",
f"pipe:aws s3 cp {shards_path} -",
)
assert len(datasets_to_train_on) > 0, "Must train on at least one dataset"
# Validate args
for i in range(len(datasets_to_train_on) - 1):
assert getattr(args, f"train_num_samples_{datasets_to_train_on[i]}") // getattr(
args, f"batch_size_{datasets_to_train_on[i]}"
) == getattr(
args, f"train_num_samples_{datasets_to_train_on[i + 1]}"
) // getattr(
args, f"batch_size_{datasets_to_train_on[i + 1]}"
), "Number of batches in each dataloader must be the same"
if args.save_checkpoints_to_wandb and not args.report_to_wandb:
raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
if args.fsdp:
assert (
torch.__version__ > "2.0.1"
), "FSDP requires torch > 2.0.1"
# Set up distributed training
args.local_rank, args.rank, args.world_size = world_info_from_env()
if args.rank == 0:
print(f"Initializing distributed training with {args.world_size} GPUs.")
if args.offline:
os.environ["WANDB_MODE"] = "offline"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
device_id = init_distributed_device(args)
random_seed(args.seed)
# Initialize model
additional_kwargs = (
{"cross_attn_every_n_layers": args.cross_attn_every_n_layers}
if args.model_family == "flamingo"
else {}
)
model, image_processor, tokenizer = create_model_and_transforms(
args.vision_encoder_path,
args.vision_encoder_pretrained,
args.lm_path,
args.tokenizer_path if args.tokenizer_path else args.lm_path,
model_family=args.model_family,
use_local_files=args.offline,
gradient_checkpointing=args.gradient_checkpointing,
verbose=(args.rank == 0),
**additional_kwargs,
)
random_seed(args.seed, args.rank)
# Initialize wandb logging
if args.rank == 0 and args.report_to_wandb:
wandb.init(
project=args.wandb_project,
entity=args.wandb_entity,
name=args.run_name,
config=vars(args),
)
# Load model checkpoint (on CPU)
if args.fsdp:
args.fsdp_checkpoint_config = get_fsdp_checkpoint_config(args)
# if args do not specify a checkpoint to resume from, resume from most recent checkpoint
if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
args.resume_from_checkpoint = find_most_recent_checkpoint(args)
    if args.resume_from_checkpoint is not None:
        resume_from_epoch, resume_from_step, checkpoint = load_checkpoint(args, model)
    else:
        resume_from_epoch = 0
        resume_from_step = 0
# Initialize gradient checkpointing
if args.gradient_checkpointing:
model.init_gradient_checkpointing()
# Initialize FSDP / DDP, and ensure the model is on GPU
if args.fsdp:
auto_wrap_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
)
wrapper_kwargs = get_fsdp_config(args, device_id)
distributed_model = FSDP(
model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs
)
else:
model = model.to(device_id)
distributed_model = DDP(model, device_ids=[device_id])
# Initialize optimizer
params_with_wd, params_without_wd = model.group_params_by_weight_decay()
optimizer = torch.optim.AdamW(
[
{"params": params_with_wd, "weight_decay": args.weight_decay},
{"params": params_without_wd, "weight_decay": 0.0},
],
lr=args.learning_rate,
)
# load optimizer checkpoint
if args.resume_from_checkpoint is not None:
optim_state_dict = checkpoint["optimizer_state_dict"]
if args.fsdp:
# FSDP.set_state_dict_type(
# distributed_model,
# **args.fsdp_checkpoint_config,
# )
optim_state_dict = FSDP.optim_state_dict_to_load(
model=distributed_model, optim=optimizer, optim_state_dict=optim_state_dict
)
optimizer.load_state_dict(optim_state_dict)
# Initialize datasets
datasets = [
get_data(args, image_processor, tokenizer, dataset_name)
for dataset_name in datasets_to_train_on
]
total_training_steps = (
getattr(args, f"train_num_samples_{datasets_to_train_on[0]}")
// (getattr(args, f"batch_size_{datasets_to_train_on[0]}") * args.gradient_accumulation_steps * args.world_size)
) * args.num_epochs
if args.rank == 0:
print(f"Total training steps: {total_training_steps}")
# Initialize lr scheduler
if args.lr_scheduler == "linear":
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
elif args.lr_scheduler == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
else:
lr_scheduler = get_constant_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps
)
# load lr scheduler checkpoint
if args.resume_from_checkpoint is not None:
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
# Initialize the loss fn
loss_fn = get_loss_fn(args.loss)
# check wrapping
if args.rank == 0:
print(distributed_model)
# Start training!
print(f"Start running training on rank {args.rank}.")
for epoch in range(resume_from_epoch, args.num_epochs):
for dataset in datasets:
dataset.set_epoch(epoch)
train_one_epoch(
args=args,
model=distributed_model,
epoch=epoch,
datasets=datasets,
compute_loss_fn=loss_fn,
tokenizer=tokenizer,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
device_id=device_id,
wandb=wandb,
)
save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
if __name__ == "__main__":
main()
import time
from contextlib import suppress
import torch
from tqdm import tqdm
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os
import shutil
import wandb
import glob
from data_utils import DataInfo
import random
import numpy as np
import torch.nn as nn
def train_one_epoch(
args,
model,
epoch,
datasets: [DataInfo],
compute_loss_fn: callable,
tokenizer,
optimizer,
lr_scheduler,
device_id,
wandb,
):
"""
Helper function for running one epoch of training.
Handles logging, calling forward, backward, gradient clipping, and optimizer step.
Args:
args (argparse.Namespace): arguments from command line
model: DDP / FSDP wrapped model
epoch (int): epoch number
datasets (list): list of DataInfos, one for each dataset, to train on
compute_loss_fn (callable): function that given the model and inputs, calls forward
and returns a loss
tokenizer: tokenizer for the language model
optimizer: optimizer to step
lr_scheduler: learning rate scheduler
device_id (int): GPU device ID for this rank
wandb: wandb object for logging
"""
# calculate the number of steps in an epoch
num_batches_per_epoch = datasets[0].dataloader.num_batches
total_training_steps = num_batches_per_epoch * args.num_epochs
# set up model, autocast, and dtypes
model.train()
autocast = get_autocast(args.precision)
# set up logging
step_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
# loop through the batches in this epoch
for step_num, batches in tqdm(
enumerate(zip(*[dataset.dataloader for dataset in datasets])),
disable=args.rank != 0,
total=total_training_steps,
initial=(epoch * num_batches_per_epoch),
):
data_time_m.update(time.time() - end)
global_step = step_num + epoch * num_batches_per_epoch
# call compute_loss_fn on each dataset; call backward before continuing
losses_to_log = {}
batch_metadata_to_log = {}
for dataset_ix, (images, (input_ids, attention_mask)) in enumerate(batches):
# unpack the batch and move to device
images = images.to(device_id, non_blocking=True)
input_ids = input_ids.to(device_id, non_blocking=True)
attention_mask = attention_mask.to(device_id, non_blocking=True)
# save some metadata for logging
batch_metadata_to_log[
f"{datasets[dataset_ix].name}_num_tokens"
] = attention_mask.sum().item()
batch_metadata_to_log[f"{datasets[dataset_ix].name}_num_images"] = (
(input_ids == unwrap_model(model).media_token_id).sum().item()
)
            # build labels from input_ids, masking pad tokens out of the loss
            labels = input_ids.clone()
            labels[labels == tokenizer.pad_token_id] = -100
            # forward pass
            dataset_loss = compute_loss_fn(
                model=model,
                tokenizer=tokenizer,
                images=images,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
            losses_to_log[f"{datasets[dataset_ix].name}_loss"] = dataset_loss.item()
            # scale by gradient accumulation and this dataset's loss multiplier before backprop
            divided_loss = dataset_loss / args.gradient_accumulation_steps
            (divided_loss * datasets[dataset_ix].loss_multiplier).backward()
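        # FIXME: the LAION/MMC4 block below looks like leftover code from the
        # open_flamingo interleaved-pretraining loop; it references batch_mmc4,
        # cast_dtype, rearrange, media_token_id, and endofchunk_token_id, none of
        # which are defined in this function.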
#### MMC4 FORWARD PASS ####
images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True)
images = rearrange(images, "b (t f) c h w -> b t f c h w", f=1)
input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)
# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
        labels[labels == tokenizer.eos_token_id] = -100
for i in range(labels.shape[0]):
# remove loss for any token before the first <image> token
label_idx = 0
while (
label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id
):
labels[i][label_idx] = -100
label_idx += 1
# get index of all endofchunk tokens in the sequence
endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
for endofchunk_idx in endofchunk_idxs:
token_idx = endofchunk_idx + 1
while (
token_idx < labels.shape[1]
and labels[i][token_idx] != media_token_id
):
labels[i][token_idx] = -100
token_idx += 1
labels[labels == media_token_id] = -100
labels = labels.to(device_id)
# gradient accumulation w/ fsdp cpu offloading requires a no_sync context manager
with autocast():
loss_mmc4 = model(
vision_x=images,
lang_x=input_ids.to(device_id),
attention_mask=attention_mask.to(device_id),
labels=labels,
)[0]
# if loss is nan, skip this batch
# this hack of skipping the batch is not FSDP-compatible
if torch.isnan(loss_mmc4):
print("loss is nan, skipping this batch")
print("input_ids: ", tokenizer.batch_decode(input_ids))
print("labels: ", labels)
print("images: ", images)
optimizer.zero_grad(set_to_none=True)
continue
divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps
(divided_loss_mmc4 * args.loss_multiplier_mmc4).backward()
if (not args.freeze_lm_embeddings) and (
not args.fsdp or args.fsdp_use_orig_params
):
# Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
if args.fsdp:
embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
else:
embed_grad = (
model.module.lang_encoder.get_input_embeddings().weight.grad
)
zero_mask = torch.zeros_like(embed_grad)
zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
zero_mask[endofchunk_token_id] = torch.ones_like(
zero_mask[endofchunk_token_id]
)
if args.fsdp:
model.lang_encoder.get_input_embeddings().weight.grad = (
embed_grad * zero_mask
)
else:
model.module.lang_encoder.get_input_embeddings().weight.grad = (
embed_grad * zero_mask
)
# clip gradient norm
if args.fsdp:
model.clip_grad_norm_(1.0, norm_type=2.0)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# step optimizer and log
if (((step_num + 1) % args.gradient_accumulation_steps) == 0) or (
step_num == num_batches_per_epoch - 1
):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# step time and reset end outside of rank 0
step_time_m.update(time.time() - end)
end = time.time()
# rank 0 logging
if args.rank == 0 and args.report_to_wandb:
# calculate samples per second
throughput_metrics = compute_throughput(
args,
datasets,
batch_metadata_to_log,
step_time_m,
)
wandb.log(
{
"global_step": global_step,
"lr": optimizer.param_groups[0]["lr"],
"data_time": data_time_m.avg,
"step_time": step_time_m.avg,
**throughput_metrics,
**losses_to_log,
},
commit=True,
)
step_time_m.reset()
data_time_m.reset()
# Log loss to console
if ((step_num + 1) % args.logging_steps == 0) and args.rank == 0:
print(
f"Step {step_num+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Losses: "
+ "// ".join([f"{k}: {v:.3f}" for k, v in losses_to_log.items()])
)
def finetune_one_epoch(
args,
resume_from_step,
model,
epoch,
dataset: DataInfo,
compute_loss_fn: callable,
tokenizer,
optimizer,
lr_scheduler,
device_id,
wandb,
):
"""
Helper function for running one epoch of training.
Handles logging, calling forward, backward, gradient clipping, and optimizer step.
Args:
args (argparse.Namespace): arguments from command line
model: DDP / FSDP wrapped model
epoch (int): epoch number
datasets (list): list of DataInfos, one for each dataset, to train on
compute_loss_fn (callable): function that given the model and inputs, calls forward
and returns a loss
tokenizer: tokenizer for the language model
optimizer: optimizer to step
lr_scheduler: learning rate scheduler
device_id (int): GPU device ID for this rank
wandb: wandb object for logging
"""
# calculate the number of steps in an epoch
num_batches_per_epoch = len(dataset.dataloader)
total_training_steps = num_batches_per_epoch * args.num_epochs
# set up model, autocast, and dtypes
model.train()
autocast = get_autocast(args.precision)
# set up logging
step_time_m = AverageMeter()
data_time_m = AverageMeter()
end = time.time()
# loop through the batches in this epoch
for step_num, samples in tqdm(enumerate(dataset.dataloader),
disable=args.rank != 0,
total=total_training_steps,
initial=epoch * num_batches_per_epoch,
):
# for step_num, samples in enumerate(dataset.dataloader):
if step_num < resume_from_step:
# Jump to the resume step.
continue
data_time_m.update(time.time() - end)
global_step = step_num + epoch * num_batches_per_epoch
# call compute_loss_fn on each dataset; call backward before continuing
losses_to_log = {}
batch_metadata_to_log = {}
# images, (input_ids, attention_mask) = samples
# unpack the batch and move to device
images = samples["images"]
if not isinstance(images, list):
images = images.to(device_id, non_blocking=True)
input_ids = samples["input_ids"].to(device_id, non_blocking=True)
attention_mask = samples["attention_mask"].to(device_id, non_blocking=True)
labels = samples["labels"].to(device_id, non_blocking=True)
# save some metadata for logging
batch_metadata_to_log[
f"{dataset.name}_num_tokens"
] = attention_mask.sum().item()
batch_metadata_to_log[f"{dataset.name}_num_images"] = (
(input_ids == unwrap_model(model).media_token_id).sum().item()
)
# forward pass
loss = compute_loss_fn(
model=model,
tokenizer=tokenizer,
images=images,
image_size=samples['image_size'],
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
autocast=autocast,
)
losses_to_log["train_loss"] = loss.item()
divided_loss = loss / args.gradient_accumulation_steps
divided_loss.backward()
if args.dryrun:
del loss
del divided_loss
optimizer.zero_grad(set_to_none=True)
continue
# FIXME: Where are the special tokens added/defined?
# if (not args.freeze_lm_embeddings) and (
# not args.fsdp or args.fsdp_use_orig_params
# ):
# # Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
# if args.fsdp:
# embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
# else:
# embed_grad = (
# model.module.lang_encoder.get_input_embeddings().weight.grad
# )
# zero_mask = torch.zeros_like(embed_grad)
# zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
# zero_mask[endofchunk_token_id] = torch.ones_like(
# zero_mask[endofchunk_token_id]
# )
# if args.fsdp:
# model.lang_encoder.get_input_embeddings().weight.grad = (
# embed_grad * zero_mask
# )
# else:
# model.module.lang_encoder.get_input_embeddings().weight.grad = (
# embed_grad * zero_mask
# )
# clip gradient norm
if args.fsdp:
model.clip_grad_norm_(1.0, norm_type=2.0)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# step optimizer and log
if (((step_num + 1) % args.gradient_accumulation_steps) == 0) or (
step_num == num_batches_per_epoch - 1
):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
# step time and reset end outside of rank 0
step_time_m.update(time.time() - end)
end = time.time()
# rank 0 logging
if args.rank == 0 and args.report_to_wandb:
# calculate samples per second
throughput_metrics = compute_throughput(
args,
[dataset],
batch_metadata_to_log,
step_time_m,
)
wandb.log(
{
"global_step": global_step,
"lr": optimizer.param_groups[0]["lr"],
**losses_to_log,
"data_time": data_time_m.avg,
"step_time": step_time_m.avg,
**throughput_metrics,
},
commit=True,
)
step_time_m.reset()
data_time_m.reset()
# dist.barrier()
# Log loss to console
if ((step_num + 1) % args.logging_steps == 0) and args.rank == 0:
print(
f"Step {step_num+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Losses: "
+ "// ".join([f"{k}: {v:.3f}" for k, v in losses_to_log.items()])
)
if ((step_num + 1) % args.checkpoint_steps == 0):
save_checkpoint(model, optimizer, lr_scheduler, epoch, args, step=step_num)
def get_autocast(precision, cache_enabled=True):
"""
Parses the precision argument and returns an autocast context manager.
"""
if precision == "amp":
return torch.cuda.amp.autocast(cache_enabled=cache_enabled)
elif precision == "amp_bfloat16" or precision == "amp_bf16":
return lambda: torch.cuda.amp.autocast(
dtype=torch.bfloat16, cache_enabled=cache_enabled
)
else:
return suppress
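# Minimal usage sketch: for non-AMP precisions the helper returns `suppress`, a no-op
# context manager, so the call site `with autocast():` works uniformly across modes.
if __name__ == "__main__":
    _autocast = get_autocast("fp32")
    with _autocast():
        _ = torch.zeros(2, 2) + 1.0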
def random_seed(seed=42, rank=0):
"""Seed everything"""
torch.manual_seed(seed + rank)
torch.cuda.manual_seed(seed + rank)
np.random.seed(seed + rank)
random.seed(seed + rank)
def unwrap_model(model):
"""
Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
"""
if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
return model.module
else:
return model
################################
# Helper functions for logging #
################################
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
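# Minimal usage sketch: `val` holds the latest measurement while `avg` tracks the running mean.
if __name__ == "__main__":
    _meter = AverageMeter()
    _meter.update(2.0)
    _meter.update(4.0)
    assert _meter.val == 4.0 and _meter.avg == 3.0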
def compute_throughput(
args,
datasets,
batch_metadata,
step_time_m,
):
"""
Computes throughput metrics for logging, including samples per second and tokens per second.
"""
log = {}
for dataset in datasets:
log[f"{dataset.name}_samples_per_second_per_gpu"] = (
args.gradient_accumulation_steps * dataset.batch_size / step_time_m.val
)
log[f"{dataset.name}_samples_per_second"] = (
log[f"{dataset.name}_samples_per_second_per_gpu"] * args.world_size
)
log[f"{dataset.name}_tokens_per_second_per_gpu"] = (
args.gradient_accumulation_steps
* batch_metadata[f"{dataset.name}_num_tokens"]
/ step_time_m.val
)
log[f"{dataset.name}_tokens_per_second"] = (
log[f"{dataset.name}_tokens_per_second_per_gpu"] * args.world_size
) # this is an estimate based on rank 0
log[f"{dataset.name}_images_per_second_per_gpu"] = (
args.gradient_accumulation_steps
* batch_metadata[f"{dataset.name}_num_images"]
/ step_time_m.val
)
log[f"{dataset.name}_images_per_second"] = (
log[f"{dataset.name}_images_per_second_per_gpu"] * args.world_size
) # this is an estimate based on rank 0
return log
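# Minimal usage sketch (hypothetical numbers, assuming the DataInfo fields used in
# make_supervised_data_module): one fake dataset and a 0.5 s step time yield
# 1 * 4 / 0.5 = 8 samples per second per GPU.
if __name__ == "__main__":
    from argparse import Namespace
    _args = Namespace(gradient_accumulation_steps=1, world_size=8)
    _info = DataInfo(name="demo", dataloader=None, batch_size=4,
                     loss_multiplier=1.0, shared_epoch=None, sampler=None)
    _timer = AverageMeter()
    _timer.update(0.5)  # pretend one optimizer step took 0.5 seconds
    _metrics = compute_throughput(_args, [_info],
                                  {"demo_num_tokens": 2048, "demo_num_images": 16}, _timer)
    assert _metrics["demo_samples_per_second_per_gpu"] == 8.0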
####################################################
# Helper functions for checkpoint loading / saving #
####################################################
def find_most_recent_checkpoint(args):
"""
Returns the path of the most recent checkpoint for a given run name.
"""
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
if len(checkpoint_list) == 0:
print(f"Found no checkpoints for run {args.run_name}.")
resume_from_checkpoint = None
else:
resume_from_checkpoint = sorted(
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
)[-1]
print(f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.")
return resume_from_checkpoint
def load_checkpoint(args, model, pretrained=False):
"""
Loads a checkpoint into the model and returns the checkpoint + epoch to resume from.
Does not load the optimizer or learning rate checkpoints, but these are included in the returned checkpoint dict.
"""
if pretrained:
ckpt_path = args.pretrained
else:
ckpt_path = args.resume_from_checkpoint
if args.rank == 0:
print(f"Loading checkpoint from {ckpt_path}")
checkpoint = torch.load(ckpt_path, map_location="cpu")
# msd = checkpoint.pop("model_state_dict")
if "model_state_dict" in checkpoint:
msd = checkpoint.pop("model_state_dict")
else:
print("No 'model_state_dict' found. Using entire checkpoint as model state dict.")
msd = checkpoint
msd = {k.replace("module.", ""): v for k, v in msd.items()}
if 'vision_tokenizer.latents' in msd.keys():
msd_current = model.state_dict()
if msd_current['vision_tokenizer.latents'].shape != msd['vision_tokenizer.latents'].shape:
msd["vision_tokenizer.latents"] = msd_current['vision_tokenizer.latents'] # Random re-init.
# remove any module with vision_encoder in the name
# msd = {k: v for k, v in msd.items() if "vision_encoder" not in k}
if not pretrained:
resume_from_epoch = checkpoint["epoch"] + 1
else:
resume_from_epoch = None
if 'step' in checkpoint and checkpoint["step"] is not None:
resume_from_step = checkpoint["step"] + 1
resume_from_epoch = checkpoint["epoch"] # Resume from prev epoch at the given step.
else:
resume_from_step = 0
if args.fsdp:
FSDP.set_state_dict_type(
model,
**args.fsdp_checkpoint_config,
)
result = model.load_state_dict(msd, strict=False)
# Print missing and unexpected keys
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)
return resume_from_epoch, resume_from_step, checkpoint
def filter_state_dict_to_trainable(model, state_dict):
"""
Remove non-trainable parameters from model state dict.
Exception: Embeddings will not be removed, even if frozen.
This is because we need the new <image> <|endofchunk|> tokens to
be consistent across initializations.
"""
# first, remove frozen params
for name, p in model.named_parameters():
if "fsdp" in name:
continue
if not p.requires_grad:
name = name.replace("._checkpoint_wrapped_module", "")
if name in state_dict:
del state_dict[name]
else:
print(f"WARNING: filtering but {name} not in state_dict")
# second, remove additional duplicate params
duplicate = lambda k: (
"lang_model.old_decoder_blocks" in k
or "lang_model.gated_cross_attn_layers" in k
)
filtered_dict = {
key: value for key, value in state_dict.items() if not duplicate(key)
}
return filtered_dict
def save_checkpoint(model, optimizer, lr_scheduler, epoch, args, step=None):
"""
Save training checkpoint with model, optimizer, and lr_scheduler state.
"""
torch.cuda.empty_cache() # (Sometimes this is necessary to avoid OOM errors when saving checkpoints)
if args.fsdp:
FSDP.set_state_dict_type(
model,
**args.fsdp_checkpoint_config,
)
model_state = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer)
else:
model_state = model.state_dict()
optim_state = optimizer.state_dict()
if args.rank == 0:
model_state = filter_state_dict_to_trainable(model, model_state)
if not os.path.exists(args.run_name):
os.makedirs(args.run_name)
checkpoint_dict = {
"epoch": epoch,
"step": step,
"model_state_dict": model_state,
"optimizer_state_dict": optim_state,
"lr_scheduler_state_dict": lr_scheduler.state_dict(),
}
if args.no_save_optim_state and step is None:
del checkpoint_dict['optimizer_state_dict']
del checkpoint_dict['lr_scheduler_state_dict']
if step is not None:
save_name = f"{args.run_name}/checkpoint_{step}.pt"
else:
save_name = f"{args.run_name}/checkpoint_{epoch}.pt"
print(f"Saving checkpoint to {save_name}")
torch.save(checkpoint_dict, save_name)
if args.report_to_wandb and args.save_checkpoints_to_wandb:
wandb.save(f"{save_name}")
if args.delete_previous_checkpoint:
if epoch > 0:
os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt")
else:
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
if len(checkpoint_list) > 1:
# Delete the oldest step checkpoint so only the most recent ones are kept.
oldest_checkpoint = sorted(
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
)[0]
os.remove(f"{oldest_checkpoint}")
import torch
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'  # HF mirror endpoint
from pathlib import Path
import argparse
from omegaconf import OmegaConf
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor
from open_flamingo import create_model_and_transforms
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--dest_fn",
type=str,
default="/blip-3_pytorch/pretrain_model/xgen-mm-phi3-mini-base-r-v1.5.pt",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
# Load model from HF hub.
#model_name_or_path = "/blip-3/pretrain_model/xgen-mm-phi3-mini-base-r-v1.5/"
model_name_or_path = "Salesforce/xgen-mm-phi3-mini-base-r-v1.5"
model = AutoModelForVision2Seq.from_pretrained(
model_name_or_path, trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, trust_remote_code=True, use_fast=True, legacy=False
)
image_processor = AutoImageProcessor.from_pretrained(
model_name_or_path, trust_remote_code=True
)
tokenizer = model.update_special_tokens(tokenizer)
# Test weight loading.
# Set local model configs.
cfg = dict(
model_family="xgenmm_v1",
lm_path="microsoft/Phi-3-mini-4k-instruct",
vision_encoder_path="google/siglip-so400m-patch14-384",
vision_encoder_pretrained="google",
num_vision_tokens=128,
image_aspect_ratio="anyres",
anyres_patch_sampling=True,
anyres_grids=[(1, 2), (2, 1), (2, 2), (3, 1), (1, 3)],
)
cfg = OmegaConf.create(cfg)
additional_kwargs = {
"num_vision_tokens": cfg.num_vision_tokens,
"image_aspect_ratio": cfg.image_aspect_ratio,
"anyres_patch_sampling": cfg.anyres_patch_sampling,
}
# Initialize the model.
local_model, _, _ = create_model_and_transforms(
clip_vision_encoder_path=cfg.vision_encoder_path,
clip_vision_encoder_pretrained=cfg.vision_encoder_pretrained,
lang_model_path=cfg.lm_path,
tokenizer_path=cfg.lm_path,
model_family=cfg.model_family,
**additional_kwargs,
)
try:
local_model.load_state_dict(model.vlm.state_dict(), strict=True)
print("Testing weight loading OK.")
except Exception as e:
print(e)
# Export model weight.
print(f"Saving converted model weight to {args.dest_fn}")
Path(args.dest_fn).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.vlm.state_dict(), args.dest_fn)
#IMAGE_FOLDER_DICT = {
IMAGE_FOLDER_DICT_GCP = {
#"LLaVA-Pretrain": "/public/opendas/DL_DATA/LLaVA-Pretrain",
# "ai2d": "/export/home/blip3_data/ocr_datasets/ai2d",
# "dvqa": "/export/home/blip3_data/ocr_datasets/DVQA",
# "docvqa": "/export/home/blip3_data/ocr_datasets/DocVQA", # Put this before vg, bc docvqa files contain characters.
# "ChartQA_Dataset": "/export/home/blip3_data/ocr_datasets/chartQA/ChartQA_Dataset",
"coco/som_train2017": "/blip-3_pytorch/dataset/SoM-LLaVA/som_train2017",
# "coco/train2017": "/export/home/blip3_data/coco/images/train2017",
# "ocr_vqa": "/export/home/blip3_data/ocr_vqa",
# "vg": "/export/home/blip3_data/visual-genome",
# "gqa": "/export/home/blip3_data/GQA",
# "share_textvqa": "/export/home/blip3_data/share_textvqa", # Put this before the substring below.
# "textvqa": "/export/home/blip3_data/TextVQA",
# 'wikiart': "/export/home/blip3_data/wikiart",
# 'sam/images': '/export/home/blip3_data/sam/images',
# "web-celebrity": "/export/home/blip3_data/web-celebrity",
# "web-landmark": "/export/home/blip3_data/web-landmark",
# "llava/llava_pretrain": "/export/home/blip3_data/llava/llava_pretrain",
# "train2017": "/export/home/blip3_data/coco/images/train2017",
}
# Data args.
# Note: this is an example data config, not for reproducing xgen-mm-instruct.
data_path: {
#'/blip-3/dataset/blip_laion_cc_sbu_558k_fixed.json': 558128
#'/blip-3/dataset/LLaVA-Pretrain/blip_laion_cc_sbu_558k_fixed.json': 558128
# '/mnt/xgen-mm/LLaVA-Pretrain/llava_all_path.json': 558128
# # Llava-665K
# '/export/home/blip3_data/llava_instruct_665k_sharegpt4v/annotations/sharegpt4v_mix665k_cap23k_coco-ap9k_lcs3k_sam9k_div2k.json': 665058, # Total: 665058,
# SoM-llava.
'/blip-3_pytorch/dataset/SoM-LLaVA/som_qa_coco20k.json': 20160,
'/blip-3_pytorch/dataset/SoM-LLaVA/som_listing_coco10k.json': 10000,
# # Text-only. (37k)
# # '/export/share/manlis/data/allava-4v/Evol-Instruct-GPT4-Turbo-143K-filterd.json': 20000, # Total: 143000
# '/export/home/blip3_data/text-only-sft-data/Python-Code-23k-ShareGPT.json': 10000, # Total 22608
# '/export/home/blip3_data/text-only-sft-data/gsm8k-main-train.json': 7473,
# '/export/home/blip3_data/text-only-sft-data/slimorca-dedup.json': 10000, # Total: 363491
# '/export/home/blip3_data/text-only-sft-data/orca-math-word-problems-200k.json': 10000, # Total: 200035
# '/export/home/blip3_data/text-only-sft-data/lima-train.json': 5000, #Total: 1030
# # OCR (72k)
# '/export/home/blip3_data/ocr_datasets/ai2d/ai2d_multichoice_llava_format_single_img_token_train.json': 10000, # Total: 2482
# '/export/home/blip3_data/ocr_datasets/DVQA/dvqa_llava_format.json': 20000, # Total: 2325316
# '/export/home/blip3_data/ocr_datasets/DocVQA/docvqa_llava_format.json': 20649,
# '/export/home/blip3_data/ocr_datasets/chartQA/chartqa_train_augmented_llava_format.json': 20901,
# '/export/home/blip3_data/ocr_datasets/chartQA/chartqa_train_human_llava_format.json': 7398,
}
Run `python /blip-3_pytorch/down_dataset_hf.py` to download the SoM-LLaVA dataset from Hugging Face.
import os
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
from huggingface_hub import hf_hub_download, snapshot_download
snapshot_download(repo_id="zzxslp/SoM-LLaVA", repo_type="dataset", local_dir='/blip-pytorch/dataset/SoM-LLaVA')
import os
os.environ['CURL_CA_BUNDLE'] = ''
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
from huggingface_hub import hf_hub_download, snapshot_download
snapshot_download(repo_id="Salesforce/xgen-mm-phi3-mini-base-r-v1.5", local_dir='/blip-3_pytorch/pretrain_model/xgen-mm-phi3-mini-base-r-v1.5')
import json
from pathlib import Path
# 1. Configure the paths (only one parameter needs editing: the path to the JSON file).
json_file = Path("/public/opendas/DL_DATA/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json")
# The directory containing the JSON is used as the image root (derived automatically).
image_root_dir = json_file.parent
# Save path of the fixed JSON (suffix "_fixed" so the original file is not overwritten).
fixed_json_file = "/blip-3/dataset/blip_laion_cc_sbu_558k_fixed.json"
# 2. Read the original JSON data.
print(f"Reading original JSON: {json_file}")
with open(json_file, "r", encoding="utf-8") as f:
    data = json.load(f)
if not isinstance(data, list):
    raise ValueError("The JSON file must contain a list (one element per sample).")
print(f"Read {len(data)} samples")
# 3. Fix the image path of every sample (core logic).
fixed_count = 0    # samples whose path was successfully fixed
missing_count = 0  # samples whose image path is missing
for idx, sample in enumerate(data):
    # Get the image path relative to the JSON file. Common field names are "image",
    # "file_path" and "img_path"; adjust sample.get(...) if your JSON uses another field.
    relative_img_path = sample.get("image")
    if not relative_img_path:
        missing_count += 1
        print(f"Warning: sample {idx} has no image path, skipping")
        continue
    # Build the absolute path: directory of the JSON + relative path.
    absolute_img_path = image_root_dir / relative_img_path
    # Convert to str so the Path object is serialized as a string in the JSON output.
    sample["image"] = str(absolute_img_path)
    # (Optional) verify that the file exists to catch invalid images early.
    if not absolute_img_path.exists():
        print(f"Warning: image of sample {idx} does not exist -> {absolute_img_path}")
    else:
        fixed_count += 1
# 4. Save the fixed JSON.
with open(fixed_json_file, "w", encoding="utf-8") as f:
    json.dump(data, f, indent=2, ensure_ascii=False)
# Print a summary.
print("\n" + "=" * 50)
print("Path fixing finished!")
print(f"Original samples: {len(data)}")
print(f"Fixed (path joined and/or valid): {fixed_count}")
print(f"Samples missing an image path: {missing_count}")
print(f"Fixed JSON: {fixed_json_file}")
# Print the first sample's path as a quick sanity check.
if len(data) > 0 and "image" in data[0]:
    print(f"Example path (first sample): {data[0]['image']}")
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/export/share/anasawadalla/miniconda3/envs/xgenmm-release-clone/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"\n",
"from omegaconf import OmegaConf\n",
"from functools import partial\n",
"from PIL import Image\n",
"import torch\n",
"\n",
"from open_flamingo import create_model_and_transforms \n",
"from open_flamingo.train.any_res_data_utils import process_images"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inference code"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
"Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.60it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"xgenmm_v1 model initialized with 3,931,031,619 trainable parameters\n",
"==========Trainable Parameters\n",
"Vision encoder: 0 trainable parameters\n",
"Vision tokenizer: 109,901,568 trainable parameters\n",
"Language model: 3,821,130,051 trainable parameters\n",
"==========Total Parameters\n",
"Vision encoder: 428,225,600 parameters\n",
"Vision tokenizer: 109,901,568 parameters\n",
"Language model: 3,821,130,051 parameters\n",
"==========\n"
]
}
],
"source": [
"# Set model configs.\n",
"model_ckpt=\"path/to/your/local/checkpoint.pt\"\n",
"cfg = dict(\n",
" model_family = 'xgenmm_v1',\n",
" lm_path = 'microsoft/Phi-3-mini-4k-instruct',\n",
" vision_encoder_path = 'google/siglip-so400m-patch14-384',\n",
" vision_encoder_pretrained = 'google',\n",
" num_vision_tokens = 128,\n",
" image_aspect_ratio = 'anyres',\n",
" anyres_patch_sampling = True,\n",
" anyres_grids = [(1,2),(2,1),(2,2),(3,1),(1,3)],\n",
" ckpt_pth = model_ckpt,\n",
")\n",
"cfg = OmegaConf.create(cfg)\n",
"\n",
"additional_kwargs = {\n",
" \"num_vision_tokens\": cfg.num_vision_tokens,\n",
" \"image_aspect_ratio\": cfg.image_aspect_ratio,\n",
" \"anyres_patch_sampling\": cfg.anyres_patch_sampling,\n",
"}\n",
"\n",
"# Initialize the model.\n",
"model, image_processor, tokenizer = create_model_and_transforms(\n",
" clip_vision_encoder_path=cfg.vision_encoder_path,\n",
" clip_vision_encoder_pretrained=cfg.vision_encoder_pretrained,\n",
" lang_model_path=cfg.lm_path,\n",
" tokenizer_path=cfg.lm_path,\n",
" model_family=cfg.model_family,\n",
" **additional_kwargs)\n",
"\n",
"ckpt = torch.load(cfg.ckpt_pth)[\"model_state_dict\"]\n",
"model.load_state_dict(ckpt, strict=True)\n",
"torch.cuda.empty_cache()\n",
"model = model.eval().cuda()\n",
"\n",
"base_img_size = model.base_img_size\n",
"anyres_grids = []\n",
"for (m,n) in cfg.anyres_grids:\n",
" anyres_grids.append([base_img_size*m, base_img_size*n])\n",
"model.anyres_grids = anyres_grids"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Preprocessing utils.\n",
"\n",
"image_proc = partial(process_images, image_processor=image_processor, model_cfg=cfg)\n",
"\n",
"def apply_prompt_template(prompt, cfg):\n",
" if 'Phi-3' in cfg.lm_path:\n",
" s = (\n",
" '<|system|>\\nA chat between a curious user and an artificial intelligence assistant. '\n",
" \"The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\\n\"\n",
" f'<|user|>\\n{prompt}<|end|>\\n<|assistant|>\\n'\n",
" )\n",
" else:\n",
" raise NotImplementedError\n",
" return s"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Prep image input.\n",
"image_path_1 = 'example_images/image-1.jpeg'\n",
"image_path_2 = 'example_images/image-2.jpeg'\n",
"\n",
"image_1 = Image.open(image_path_1).convert('RGB')\n",
"image_2 = Image.open(image_path_2).convert('RGB')\n",
"images = [image_1, image_2]\n",
"image_size = [image_1.size, image_2.size]\n",
"image_size = [image_size]\n",
"vision_x = [image_proc([img]) for img in images]\n",
"vision_x = [vision_x]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Prep language input.\n",
"prompt = \"Look at this image <image> and this image <image>. What is in the second image?\"\n",
"prompt = apply_prompt_template(prompt, cfg)\n",
"lang_x = tokenizer([prompt], return_tensors=\"pt\")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/export/share/anasawadalla/miniconda3/envs/xgenmm-release-clone/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:515: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n",
" warnings.warn(\n",
"You are not running the flash-attention implementation, expect numerical differences.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"A black and white cat. \n"
]
}
],
"source": [
"# Run inference.\n",
"kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=1024, top_p=None, num_beams=1)\n",
"\n",
"generated_text = model.generate(\n",
" vision_x=vision_x, \n",
" lang_x=lang_x['input_ids'].to(torch.device('cuda:0')), \n",
" image_size=image_size,\n",
" attention_mask=lang_x['attention_mask'].to(torch.device('cuda:0')), \n",
" **kwargs_default)\n",
" \n",
"generated_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)\n",
"if 'Phi-3' in cfg.lm_path:\n",
" text = generated_text.split('<|end|>')[0]\n",
"else:\n",
" text=generated_text\n",
"\n",
"print(text)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Metadata-Version: 2.4
Name: open_flamingo
Version: 2.0.1
Summary: An open-source framework for training large multimodal models
License: MIT
Keywords: machine learning
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: einops
Requires-Dist: einops-exts
Requires-Dist: transformers
Requires-Dist: torch>=2.0.1
Requires-Dist: pillow
Requires-Dist: open_clip_torch>=2.16.0
Requires-Dist: sentencepiece
Provides-Extra: training
Requires-Dist: torchvision; extra == "training"
Requires-Dist: braceexpand; extra == "training"
Requires-Dist: webdataset; extra == "training"
Requires-Dist: tqdm; extra == "training"
Requires-Dist: wandb; extra == "training"
Provides-Extra: all
Requires-Dist: sentencepiece; extra == "all"
Requires-Dist: braceexpand; extra == "all"
Requires-Dist: transformers; extra == "all"
Requires-Dist: torch>=2.0.1; extra == "all"
Requires-Dist: tqdm; extra == "all"
Requires-Dist: einops; extra == "all"
Requires-Dist: webdataset; extra == "all"
Requires-Dist: torchvision; extra == "all"
Requires-Dist: pillow; extra == "all"
Requires-Dist: einops-exts; extra == "all"
Requires-Dist: open_clip_torch>=2.16.0; extra == "all"
Requires-Dist: wandb; extra == "all"
Dynamic: classifier
Dynamic: description
Dynamic: description-content-type
Dynamic: keywords
Dynamic: license
Dynamic: license-file
Dynamic: provides-extra
Dynamic: requires-dist
Dynamic: summary
# BLIP-3
## Paper
xGen-MM (BLIP-3): A Family of Open Large Multimodal Models
https://arxiv.org/pdf/2408.08872
## Model Architecture
BLIP-3, also known as xGen-MM, is a framework for developing Large Multimodal Models (LMMs). The framework includes carefully curated datasets, training recipes, model architectures, and a resulting suite of LMMs. xGen-MM is short for xGen-MultiModal and extends the foundation AI models of the Salesforce xGen initiative. The models have been rigorously evaluated on a range of tasks, including single-image and multi-image benchmarks. The pretrained base model shows strong in-context learning ability, and the instruction-tuned model delivers competitive performance among open-source LMMs of similar size. In addition, a safety-tuned model trained with DPO is introduced to mitigate harmful behaviors such as hallucination and improve safety.
## Environment Setup
### Docker (Option 1)
```
# Pull the Docker image (available from the SourceFind image registry):
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-py3.10-dtk24.04.3-ubuntu20.04
# Create and start the container:
docker run -it --network=host -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged=true --device=/dev/kfd --device=/dev/dri/ --ipc=host --group-add video --privileged --name <your_project_name> <image_id> bash
# Install dependencies:
python setup.py install
pip install omegaconf einops_exts ftfy transformers==4.47.0 wandb braceexpand webdataset -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
# Install protoc
wget https://github.com/protocolbuffers/protobuf/releases/download/v3.19.0/protoc-3.19.0-linux-x86_64.zip
unzip protoc-3.19.0-linux-x86_64.zip -d $HOME/protoc
echo 'export PATH="$HOME/protoc/bin:$PATH"' >> ~/.bashrc
source ~/.bashrc
pip install protobuf==3.19.0
```
### Dockerfile (Option 2)
```
docker build --no-cache -t blip3:latest .
docker run -it --network=host --name=blip3 --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 -v /opt/hyhal/:/opt/hyhal/:ro -v /usr/local/hyhal:/usr/local/hyhal:ro blip3:latest bash
# Install dependencies:
python setup.py install
pip install omegaconf einops_exts ftfy transformers==4.47.0 wandb braceexpand webdataset -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
### Anaconda (Option 3)
```
1. Create the conda virtual environment:
conda create -n blip3 python=3.10
2. The toolkits and deep learning libraries required by this project for DCU GPUs can be downloaded from the developer community: https://developer.hpccube.com/tool/
DTK driver: dtk24.04.3
python: python3.10
torch: 2.1.0
```
Tips: the DTK, python, torch and other DCU-related toolkit versions above must match each other exactly.
```
3. Install the other non-special libraries from the requirements file:
pip install -r requirements-training.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
python setup.py install
pip install omegaconf einops_exts ftfy transformers==4.47.0 wandb braceexpand webdataset -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
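After any of the three setup options, you can quickly verify the environment with a minimal sketch like the one below (it only assumes the torch and transformers versions installed above):
```
import torch
import transformers

# Versions installed above (expected: torch 2.1.0, transformers 4.47.0).
print("torch:", torch.__version__)
print("transformers:", transformers.__version__)
# On DTK/DCU builds the accelerator is normally exposed through the torch.cuda API.
print("device available:", torch.cuda.is_available())
```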
# Training
## Dataset
The model supports JSON dataset files in the LLaVA format; the JSON structure is shown below. See the reference dataset [llava_pretrain](http://113.200.138.88:18080/aimodels/llava_pretrain). You can use several different datasets at once.
Next, configure the [`data/example_data_config.yaml`](./data_configs/example_data_config.yaml) file with the path and image count of every JSON file. If your JSON files store relative image paths, you also need to configure the path-mapping file [`data/data_paths.py`](./data/data_paths.py); a mapping sketch is given after the JSON example below.
```
YAML file:
data_path: {
'/path/to/llava_pretrain.json': 558128,
'/path/to/som_qa_coco20k.json': 20160,
'/path/to/som_listing_coco10k.json': 10000,
}
```
```
JSON file:
{
"id": "000000033471",
"image": "coco/train2017/000000033471.jpg",
"conversations": [
{
"from": "human",
"value": "<image>\nWhat are the colors of the bus in the image?"
},
{
"from": "gpt",
"value": "The bus in the image is white and red."
},
...
]
}
```
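If the `image` fields are relative paths like `coco/train2017/000000033471.jpg` above, each path prefix has to be mapped to an absolute image folder in `data/data_paths.py`. The snippet below is a minimal sketch of such a mapping; the folder on the right-hand side is an assumed local path and should be adjusted to your setup:
```
# data_paths.py sketch: map a relative-path prefix to the absolute image folder.
IMAGE_FOLDER_DICT_GCP = {
    "coco/train2017": "/blip-3_pytorch/dataset/coco/train2017",  # assumed local folder
}
```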
The directory structure of the LLaVA-Pretrain dataset is as follows:
```
/path/to/LLaVA-Pretrain/
├── blip_laion_cc_sbu_558k.json
├── 00000
│ ├── 000000010.jpg
│ ├── 000000012.jpg
│ └── ...
├── 00001
├── 00002
└── ...
```
## Fine-tuning
#### Pretrained weights
The pretrained model `xgen-mm-phi3-mini-base-r-v1.5` can be downloaded quickly from SCNet via this [download link](http://113.200.138.88:18080/aimodels/xgen-mm-phi3-mini-base-r-v1.5).
Then run the following script to convert it into a native PyTorch .pt file:
```
# Set dest_fn to the output path and .pt file name, and set model_name_or_path to the path of the pretrained weights
python convert_hf_model.py
```
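Optionally, sanity-check the exported weights before training. The sketch below assumes the default `dest_fn` of `convert_hf_model.py`; the exported file is a plain state dict saved with `torch.save`, without a `model_state_dict` wrapper:
```
import torch

# Load the converted checkpoint on CPU and inspect it (path = default dest_fn; adjust if you changed it).
sd = torch.load("/blip-3_pytorch/pretrain_model/xgen-mm-phi3-mini-base-r-v1.5.pt", map_location="cpu")
print(len(sd), "tensors")
print(list(sd.keys())[:5])  # a few parameter names, e.g. vision tokenizer / language model weights
```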
#### Single node, multiple GPUs
```
bash scripts/finetune.sh
```
The parameters of the training script are described below:
* `exp_name`: name of the training run (log file name)
* `data_path`: path to the YAML data config
* `pretrained_ckpt`: path to the .pt file
* `--nproc_per_node=2`: number of GPUs used for multi-GPU training
* `--nnodes=1`: number of nodes
* `--master_port 9650`: port
* `--lm_path`: path to the language model (LM); default "microsoft/Phi-3-mini-4k-instruct"
* `--tokenizer_path`: path to the tokenizer used for text data; default "microsoft/Phi-3-mini-4k-instruct"
* `--vision_encoder_path`: vision encoder; default "google/siglip-so400m-patch14-384"
## Results
### Application Scenarios
### Algorithm Category
Image-to-text
### Key Application Industries
AIGC, design
## Source Repository and Issue Reporting
- https://developer.sourcefind.cn/codes/modelzoo/blip-3
## References
- https://github.com/salesforce/LAVIS/tree/xgen-mm