# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from typing import Optional
from torch_geometric.datasets import *
from torch_geometric.data import Dataset
from .pyg_dataset import GraphormerPYGDataset
import torch.distributed as dist
import os.path as osp
import pickle
import torch
from torch_geometric.datasets import ZINC
from torch_geometric.data import Data


def _run_on_rank_zero(fn):
    """Call ``fn`` on rank 0 only, then synchronise every rank.

    When ``torch.distributed`` is not initialised, ``fn`` is simply called
    and no barrier is issued.
    """
    if not dist.is_initialized() or dist.get_rank() == 0:
        fn()
    if dist.is_initialized():
        dist.barrier()


class MyQM7b(QM7b):
    """QM7b whose download/process steps run on rank 0 only."""

    def download(self):
        """Run the parent download on rank 0, then barrier."""
        _run_on_rank_zero(super().download)

    def process(self):
        """Run the parent processing on rank 0, then barrier."""
        _run_on_rank_zero(super().process)


class MyQM9(QM9):
    """QM9 whose download/process steps run on rank 0 only."""

    def download(self):
        """Run the parent download on rank 0, then barrier."""
        _run_on_rank_zero(super().download)

    def process(self):
        """Run the parent processing on rank 0, then barrier."""
        _run_on_rank_zero(super().process)


class MyZINC(ZINC):
    """ZINC built from ``<split>.pickle`` files found in ``raw_dir``.

    ``download`` fetches nothing; ``process`` reads ``train.pickle``,
    ``val.pickle`` and ``test.pickle`` from ``self.raw_dir``.
    """

    def download(self):
        """Fetch nothing; only take part in the barrier when distributed.

        NOTE(review): presumably the raw pickles are placed in ``raw_dir``
        by hand before use -- confirm.
        """
        if dist.is_initialized():
            dist.barrier()

    def process(self):
        """Convert the raw pickles on rank 0, then barrier."""
        _run_on_rank_zero(self._convert_raw_splits)

    def _convert_raw_splits(self):
        """Turn each split's pickled graphs into a collated processed file."""
        for i, split in enumerate(['train', 'val', 'test']):
            input_path = osp.join(self.raw_dir, f'{split}.pickle')
            # NOTE: unpickling can execute arbitrary code; only load raw
            # files that come from a trusted source.
            with open(input_path, 'rb') as f:
                graphs = pickle.load(f)

            data_list = []
            for g in graphs:
                # Atom types become a long column vector, one row per node.
                x = g['atom_type'].to(torch.long).view(-1, 1)
                bond_info = g['bond_type']
                y = (
                    g['logP_SA_cycle_normalized']
                    .clone()
                    .detach()
                    .view(1, -1)
                    .to(torch.float)
                )
                # First two columns are used as edge endpoints, the third
                # as the edge attribute.
                edge_index = bond_info[:, :2].t().contiguous().to(torch.long)
                edge_attr = bond_info[:, 2].to(torch.long)

                data = Data(
                    x=x, edge_index=edge_index, edge_attr=edge_attr, y=y
                )
                data.num_nodes = len(x)
                data_list.append(data)

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            data, slices = self.collate(data_list)
            # assumes processed_paths is ordered train/val/test -- TODO confirm
            torch.save((data, slices), self.processed_paths[i])


class MyMoleculeNet(MoleculeNet):
    """MoleculeNet whose download/process steps run on rank 0 only."""

    def download(self):
        """Run the parent download on rank 0, then barrier."""
        _run_on_rank_zero(super().download)

    def process(self):
        """Run the parent processing on rank 0, then barrier."""
        _run_on_rank_zero(super().process)


class PYGDatasetLookupTable:
    """Factory that maps a dataset spec string to a Graphormer dataset."""

    @staticmethod
    def GetPYGDataset(dataset_spec: str, seed: int) -> Optional[Dataset]:
        """Build the ``GraphormerPYGDataset`` described by ``dataset_spec``.

        Args:
            dataset_spec: Either ``"<name>"`` or
                ``"<name>:<key>=<value>,<key>=<value>,..."``. Recognised
                names are ``qm7b``, ``qm9``, ``zinc`` and ``moleculenet``;
                ``moleculenet`` needs a ``name=<dataset>`` parameter.
            seed: Passed through to ``GraphormerPYGDataset``.

        Returns:
            The wrapped dataset. For ``zinc`` the wrapper is given the
            separate train/val/test splits; otherwise it is given the single
            underlying dataset.

        Raises:
            ValueError: If the spec contains more than one ``:``, names an
                unknown dataset, has a malformed ``key=value`` parameter, or
                selects ``moleculenet`` without a ``name`` parameter.
        """
        split_result = dataset_spec.split(":")
        if len(split_result) == 2:
            name = split_result[0]
            params = split_result[1].split(",")
        elif len(split_result) == 1:
            name = dataset_spec
            params = []
        else:
            # Previously this fell through with `name` unbound.
            raise ValueError(
                f"Invalid dataset spec {dataset_spec!r} for pyg source; "
                "expected 'name' or 'name:key=value,...'."
            )

        inner_dataset = None
        train_set = None
        valid_set = None
        test_set = None

        root = "dataset"
        if name == "qm7b":
            inner_dataset = MyQM7b(root=root)
        elif name == "qm9":
            inner_dataset = MyQM9(root=root)
        elif name == "zinc":
            # This instance is not passed on below; it is still constructed
            # to keep any side effects of construction unchanged.
            inner_dataset = MyZINC(root=root)
            train_set = MyZINC(root=root, split="train")
            valid_set = MyZINC(root=root, split="val")
            test_set = MyZINC(root=root, split="test")
        elif name == "moleculenet":
            nm = None
            for param in params:
                # `key`, not `name`: do not clobber the dataset name.
                key, value = param.split("=")
                if key == "name":
                    nm = value
            if nm is None:
                raise ValueError(
                    "Dataset spec 'moleculenet' requires a 'name=<dataset>' "
                    "parameter, e.g. 'moleculenet:name=esol'."
                )
            inner_dataset = MyMoleculeNet(root=root, name=nm)
        else:
            raise ValueError(f"Unknown dataset name {name} for pyg source.")

        if train_set is not None:
            return GraphormerPYGDataset(
                None,
                seed,
                None,
                None,
                None,
                train_set,
                valid_set,
                test_set,
            )
        return GraphormerPYGDataset(inner_dataset, seed)