Unverified Commit f247d29f authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

Minor fix to DGL Enter (#3753)



* [Fix] Convert float64 to float32 when creating tensor

* fix
Co-authored-by: default avatarRhettYing <rhett_ying@qq.com>
Co-authored-by: default avatarRhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
parent 5558ce29
...@@ -161,9 +161,9 @@ DataFactory.register( ...@@ -161,9 +161,9 @@ DataFactory.register(
DataFactory.register( DataFactory.register(
"csv", "csv",
import_code="from dgl.data import DGLCSVDataset", import_code="from dgl.data import CSVDataset",
extra_args={"data_path": "./"}, extra_args={"data_path": "./"},
class_name="DGLCSVDataset({})", class_name="CSVDataset({})",
allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"]) allowed_pipeline=["nodepred", "nodepred-ns", "linkpred"])
DataFactory.register( DataFactory.register(
......
...@@ -76,7 +76,7 @@ class AsNodePredDataset(DGLDataset): ...@@ -76,7 +76,7 @@ class AsNodePredDataset(DGLDataset):
self.split_ratio = split_ratio self.split_ratio = split_ratio
self.target_ntype = target_ntype self.target_ntype = target_ntype
super().__init__(self.dataset.name + '-as-nodepred', super().__init__(self.dataset.name + '-as-nodepred',
hash_key=(split_ratio, target_ntype), **kwargs) hash_key=(split_ratio, target_ntype, dataset.name, 'nodepred'), **kwargs)
def process(self): def process(self):
is_ogb = hasattr(self.dataset, 'get_idx_split') is_ogb = hasattr(self.dataset, 'get_idx_split')
...@@ -211,7 +211,7 @@ class AsLinkPredDataset(DGLDataset): ...@@ -211,7 +211,7 @@ class AsLinkPredDataset(DGLDataset):
Dataset("cora_v2", num_graphs=1, save_path=...) Dataset("cora_v2", num_graphs=1, save_path=...)
>>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1]) >>> new_ds = dgl.data.AsNodePredDataset(ds, [0.8, 0.1, 0.1])
>>> print(new_ds) >>> print(new_ds)
Dataset("cora_v2-as-edgepred", num_graphs=1, save_path=/home/ubuntu/.dgl/cora_v2-as-edgepred) Dataset("cora_v2-as-linkpred", num_graphs=1, save_path=/home/ubuntu/.dgl/cora_v2-as-linkpred)
>>> print(hasattr(new_ds, "get_test_edges")) >>> print(hasattr(new_ds, "get_test_edges"))
True True
""" """
...@@ -226,8 +226,8 @@ class AsLinkPredDataset(DGLDataset): ...@@ -226,8 +226,8 @@ class AsLinkPredDataset(DGLDataset):
self.dataset = dataset self.dataset = dataset
self.split_ratio = split_ratio self.split_ratio = split_ratio
self.neg_ratio = neg_ratio self.neg_ratio = neg_ratio
super().__init__(dataset.name + '-as-edgepred', super().__init__(dataset.name + '-as-linkpred',
hash_key=(neg_ratio, split_ratio), **kwargs) hash_key=(neg_ratio, split_ratio, dataset.name, 'linkpred'), **kwargs)
def process(self): def process(self):
if self.split_ratio is None: if self.split_ratio is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment