"vscode:/vscode.git/clone" did not exist on "600ef8a4dcaa830c1dc2cb8fb7ff23f20fce5bd7"
test_dataset.py 1.53 KB
Newer Older
1
2
3
4
import os
import tempfile

import pydantic
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import pytest
from dgl import graphbolt as gb


def test_Dataset():
    """Check that every accessor of the base ``gb.Dataset`` is unimplemented.

    Each of ``train_set``, ``validation_set``, ``test_set``, ``graph`` and
    ``feature`` is called on a default-constructed ``gb.Dataset`` and must
    raise ``NotImplementedError``.
    """
    base_dataset = gb.Dataset()
    accessor_names = (
        "train_set",
        "validation_set",
        "test_set",
        "graph",
        "feature",
    )
    for accessor_name in accessor_names:
        # Look up and call the accessor inside the context so that either
        # step raising NotImplementedError satisfies the check.
        with pytest.raises(NotImplementedError):
            getattr(base_dataset, accessor_name)()
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53


def test_OnDiskDataset_TVTSet():
    """Test OnDiskDataset with TVTSet.

    A YAML description whose ``train_set`` entries use the ``torch`` and
    ``numpy`` formats must be accepted by ``gb.OnDiskDataset``; the same
    description with unrecognised ``format`` values must be rejected with
    ``pydantic.ValidationError``.
    """

    def _write_yaml(path, text):
        # Overwrite ``path`` with ``text``. The explicit encoding keeps the
        # file independent of the platform's locale default (PEP 597).
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(text)

    with tempfile.TemporaryDirectory() as test_dir:
        yaml_file = os.path.join(test_dir, "test.yaml")

        # Valid description: construction must succeed.
        # NOTE(review): the referenced set/*.pt files are never created and
        # the numpy entry points at a ``.pt`` path, so this presumably only
        # exercises YAML parsing/validation — confirm OnDiskDataset does not
        # load the files eagerly.
        yaml_content = """
        train_set:
          - - type_name: paper
              format: torch
              path: set/paper-train.pt
            - type_name: 'paper:cites:paper'
              format: numpy
              path: set/cites-train.pt
        """
        _write_yaml(yaml_file, yaml_content)
        _ = gb.OnDiskDataset(yaml_file)

        # Invalid format.
        yaml_content = """
        train_set:
          - - type_name: paper
              format: torch_invalid
              path: set/paper-train.pt
            - type_name: 'paper:cites:paper'
              format: numpy_invalid
              path: set/cites-train.pt
        """
        _write_yaml(yaml_file, yaml_content)
        with pytest.raises(pydantic.ValidationError):
            _ = gb.OnDiskDataset(yaml_file)