Unverified Commit 0b00b581 authored by Mingbang Wang's avatar Mingbang Wang Committed by GitHub
Browse files

[Misc] Polish the python code (#6443)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-39-125.ap-northeast-1.compute.internal>
parent 406d621f
......@@ -151,7 +151,7 @@ class CSCSamplingGraph(SamplingGraph):
>>> graph = gb.from_csc(indptr, indices, node_type_offset,
... type_per_edge, None, metadata)
>>> print(graph.num_nodes)
{'N0': tensor(2), 'N1': tensor(3)}
{'N0': 2, 'N1': 3}
"""
offset = self.node_type_offset
......@@ -163,7 +163,7 @@ class CSCSamplingGraph(SamplingGraph):
# Heterogenous
else:
num_nodes_per_type = {
_type: offset[_idx + 1] - offset[_idx]
_type: (offset[_idx + 1] - offset[_idx]).item()
for _type, _idx in self.metadata.node_type_to_id.items()
}
......
......@@ -3,7 +3,7 @@
import os
import shutil
from copy import deepcopy
from typing import Dict, List
from typing import Dict, List, Union
import pandas as pd
import torch
......@@ -235,9 +235,9 @@ class OnDiskTask:
def __init__(
self,
metadata: Dict,
train_set: ItemSet or ItemSetDict,
validation_set: ItemSet or ItemSetDict,
test_set: ItemSet or ItemSetDict,
train_set: Union[ItemSet, ItemSetDict],
validation_set: Union[ItemSet, ItemSetDict],
test_set: Union[ItemSet, ItemSetDict],
):
"""Initialize a task.
......@@ -245,11 +245,11 @@ class OnDiskTask:
----------
metadata : Dict
Metadata.
train_set : ItemSet or ItemSetDict
train_set : Union[ItemSet, ItemSetDict]
Training set.
validation_set : ItemSet or ItemSetDict
validation_set : Union[ItemSet, ItemSetDict]
Validation set.
test_set : ItemSet or ItemSetDict
test_set : Union[ItemSet, ItemSetDict]
Test set.
"""
self._metadata = metadata
......@@ -263,17 +263,17 @@ class OnDiskTask:
return self._metadata
@property
def train_set(self) -> ItemSet or ItemSetDict:
def train_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the training set."""
return self._train_set
@property
def validation_set(self) -> ItemSet or ItemSetDict:
def validation_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the validation set."""
return self._validation_set
@property
def test_set(self) -> ItemSet or ItemSetDict:
def test_set(self) -> Union[ItemSet, ItemSetDict]:
"""Return the test set."""
return self._test_set
......@@ -446,7 +446,7 @@ class OnDiskDataset(Dataset):
def _init_tvt_set(
self, tvt_set: List[OnDiskTVTSet]
) -> ItemSet or ItemSetDict:
) -> Union[ItemSet, ItemSetDict]:
"""Initialize the TVT set."""
ret = None
if (tvt_set is None) or (len(tvt_set) == 0):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment