# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from pathlib import Path import numpy as np from PIL import Image LOGGER = logging.getLogger(__name__) def get_dataloader_fn( *, data_dir: str, batch_size: int = 1, width: int = 224, height: int = 224, images_num: int = None, precision: str = "fp32", classes: int = 1000 ): def _dataloader(): image_extensions = [".gif", ".png", ".jpeg", ".jpg"] image_paths = sorted([p for p in Path(data_dir).rglob("*") if p.suffix.lower() in image_extensions]) if images_num is not None: image_paths = image_paths[:images_num] LOGGER.info( f"Creating PIL dataloader on data_dir={data_dir} #images={len(image_paths)} " f"image_size=({width}, {height}) batch_size={batch_size}" ) onehot = np.eye(classes) batch = [] for image_path in image_paths: img = Image.open(image_path.as_posix()).convert("RGB") img = img.resize((width, height)) img = (np.array(img).astype(np.float32) / 255) - np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) img = img / np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) true_class = np.array([int(image_path.parent.name)]) assert tuple(img.shape) == (height, width, 3) img = img[np.newaxis, ...] batch.append((img, image_path.as_posix(), true_class)) if len(batch) >= batch_size: ids = [image_path for _, image_path, *_ in batch] x = {"INPUT__0": np.ascontiguousarray( np.transpose(np.concatenate([img for img, *_ in batch]), (0, 3, 1, 2)).astype(np.float32 if precision == "fp32" else np.float16))} y_real = {"OUTPUT__0": onehot[np.concatenate([class_ for *_, class_ in batch])].astype( np.float32 if precision == "fp32" else np.float16 )} batch = [] yield ids, x, y_real return _dataloader