import os
from timm.data import create_dataset, create_loader, resolve_data_config
from torchvision.transforms import ToTensor

class CustomEfficientNetDataset:
    """Wrap a timm image dataset so every item also carries its file name.

    Items are ``(image, target, filename)`` triples, where ``filename`` is the
    image's base name (no directory part). This lets an inference loop report
    which file each prediction belongs to.

    Args:
        config: Accepted for interface compatibility; not read by this class.
        is_train: Forwarded to ``timm.data.create_dataset`` as ``is_training``.
        data_root: Location of the images, forwarded to ``create_dataset`` as
            ``root``.
        transform: Callable applied once to each image in ``__getitem__``.
            When ``None``, ``torchvision.transforms.ToTensor()`` is used.
            NOTE(review): timm's ``create_loader`` assigns its own pipeline to
            ``dataset.transform`` when given this object; verify for your
            timm version.
        dataset_name: Dataset name passed as the first argument of
            ``create_dataset``. The empty default lets timm choose its plain
            image-folder / image-tar reader based on ``data_root``.
    """

    def __init__(self, config, is_train=True, data_root=None, transform=None,
                 dataset_name=''):
        # The first argument of create_dataset is a dataset *name*, not a
        # path; the images themselves are located through `root`.
        # The transform is deliberately NOT handed to the inner dataset: it is
        # applied in __getitem__, so passing it here as well would run it twice.
        self.dataset = create_dataset(
            dataset_name,
            root=data_root,
            is_training=is_train,
        )
        self.transform = transform if transform is not None else ToTensor()
        self.data_root = data_root
        self.file_names = self._collect_file_names(self.dataset)

    @staticmethod
    def _collect_file_names(dataset):
        """Return the base file name of every sample in *dataset*, in index order.

        Prefers a ``filenames(basename=...)`` method (as timm's ``ImageDataset``
        provides), then a torchvision-style ``samples`` list of
        ``(path, target)`` entries; returns ``[]`` if neither is available.
        """
        filenames = getattr(dataset, 'filenames', None)
        if callable(filenames):
            return list(filenames(basename=True))
        samples = getattr(dataset, 'samples', None)
        if samples is not None:
            return [os.path.basename(item[0]) for item in samples]
        return []

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, target, filename)`` for sample *index*.

        Raises:
            IndexError: if no file name is known for *index* (e.g. the wrapped
                dataset exposes neither ``filenames()`` nor ``samples``).
        """
        img, target = self.dataset[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, target, self.file_names[index]

# Script entry point. Guarded so that importing this module -- including the
# re-import done by DataLoader worker processes under the "spawn" start
# method -- does not rebuild the dataset or iterate the whole loader.
if __name__ == "__main__":
    # Configuration: fill in your own values. Any key left out falls back to
    # generic ImageNet preprocessing defaults, so the lookups below cannot
    # raise KeyError. NOTE(review): input resolution differs per EfficientNet
    # variant; with a model at hand, timm.data.resolve_data_config({}, model=model)
    # is the better source for these values.
    config = {}
    config.setdefault('input_size', (3, 224, 224))
    config.setdefault('mean', (0.485, 0.456, 0.406))
    config.setdefault('std', (0.229, 0.224, 0.225))

    # Root directory of the ImageNet-2012 validation images.
    data_root = '/workspace/cmcc-infer-final/datasets/InfDataset/ImageNet2012/val'
    # Validation data, so build the dataset in non-training mode.
    dataset = CustomEfficientNetDataset(config=config, data_root=data_root, is_train=False)

    # Build the DataLoader. The prefetcher is disabled: with
    # use_prefetcher=True timm collates with fast_collate, which keeps only
    # (image, target) and drops the filename. default_collate keeps all three
    # fields and yields the filenames as a list of strings.
    loader = create_loader(
        dataset,
        input_size=config['input_size'],
        batch_size=32,
        is_training=False,
        use_prefetcher=False,
        interpolation='bicubic',
        mean=config['mean'],
        std=config['std'],
        num_workers=4,
    )

    # Each batch carries the file names alongside images and targets.
    # Without the prefetcher, batches are not moved to the GPU automatically;
    # move `images` to your device before running inference.
    for images, targets, filenames in loader:
        print(filenames)
