# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. from PIL import Image, UnidentifiedImageError import numpy as np import io import torch try: from torchvision.transforms import InterpolationMode BICUBIC = InterpolationMode.BICUBIC except ImportError: BICUBIC = Image.BICUBIC from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, RandomResizedCrop, Resize def _convert_image_to_rgb(image): return image.convert("RGB") def _transform(img_h, img_w): return Compose([ ToPILImage(), RandomResizedCrop((img_h, img_w), scale=(0.5, 1.0), interpolation=BICUBIC), _convert_image_to_rgb, ToTensor(), Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) class MultiModalDataset(torch.utils.data.Dataset): def __init__(self, name, data_prefix, indexed_dataset, num_samples, seq_length, seed, img_h, img_w): self.name = name self.indexed_dataset = indexed_dataset self.doc_idx = indexed_dataset.get_document_indices() self.visual_transform = _transform(img_h, img_w) def __len__(self): return self.indexed_dataset.sequence_lengths.shape[0] def __getitem__(self, idx): text_sample, mode = self.indexed_dataset.get(self.doc_idx[idx]) assert mode == 0 img_sample, mode = self.indexed_dataset.get(self.doc_idx[idx]+1) assert mode == 1 img_pad = img_sample[0].item() xs = img_sample[1:].tobytes(order='C') xs = xs[:len(xs)-img_pad] img_sample = np.array(Image.open(io.BytesIO(xs))) img_sample = self.visual_transform(img_sample).reshape(-1) return {'text': np.array(text_sample, dtype=np.int64), 'img': np.array(img_sample, dtype=np.float32)}