"git@developer.sourcefind.cn:change/sglang.git" did not exist on "183df4728260a1469612f848f980cc71266591b9"
create_chunked_dataset.py 8.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import argparse
import json
import logging
import os
import platform
import sys
import tempfile

import dgl
import numpy as np
import torch
from chunk_graph import chunk_graph
from dgl.data.utils import load_graphs, load_tensors


def create_chunked_dataset(
    root_dir, num_chunks, include_masks=False, data_fmt='numpy'
):
    """
    This function creates a sample dataset, based on MAG240 dataset.

    Parameters:
    -----------
    root_dir : string
        directory in which all the files for the chunked dataset will be stored.
    """
    # Step0: prepare chunked graph data format.
    # A synthetic mini MAG240.
    num_institutions = 1200
    num_authors = 1200
    num_papers = 1200

    def rand_edges(num_src, num_dst, num_edges):
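        # Sample unique IDs from the flattened num_src x num_dst edge space,
        # then decode row/col with // and % so no (src, dst) pair repeats.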
        eids = np.random.choice(num_src * num_dst, num_edges, replace=False)
        src = torch.from_numpy(eids // num_dst)
        dst = torch.from_numpy(eids % num_dst)

        return src, dst

    num_cite_edges = 24 * 1000
    num_write_edges = 12 * 1000
    num_affiliate_edges = 2400

    # Structure.
    data_dict = {
        ('paper', 'cites', 'paper'): rand_edges(
            num_papers, num_papers, num_cite_edges
        ),
        ('author', 'writes', 'paper'): rand_edges(
            num_authors, num_papers, num_write_edges
        ),
        ('author', 'affiliated_with', 'institution'): rand_edges(
            num_authors, num_institutions, num_affiliate_edges
        ),
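        # A second relation also named 'writes' (institution -> paper);
        # DGL distinguishes it by its canonical (src, type, dst) triple.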
        ('institution', 'writes', 'paper'): rand_edges(
            num_institutions, num_papers, num_write_edges
        ),
    }
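    # Add a reverse of the author->paper 'writes' relation so the graph can
    # also be traversed paper->author.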
    src, dst = data_dict[('author', 'writes', 'paper')]
    data_dict[('paper', 'rev_writes', 'author')] = (dst, src)
    g = dgl.heterograph(data_dict)

    # paper feat, label, year
    num_paper_feats = 3
    paper_feat = np.random.randn(num_papers, num_paper_feats)
    num_classes = 4
    paper_label = np.random.choice(num_classes, num_papers)
    paper_year = np.random.choice(2022, num_papers)
    paper_orig_ids = np.arange(0, num_papers)
    writes_orig_ids = np.arange(0, num_write_edges)

    # Optional random binary train/test/val masks per node type.
    if include_masks:
        paper_train_mask = np.random.randint(0, 2, num_papers)
        paper_test_mask = np.random.randint(0, 2, num_papers)
        paper_val_mask = np.random.randint(0, 2, num_papers)

        author_train_mask = np.random.randint(0, 2, num_authors)
        author_test_mask = np.random.randint(0, 2, num_authors)
        author_val_mask = np.random.randint(0, 2, num_authors)

        inst_train_mask = np.random.randint(0, 2, num_institutions)
        inst_test_mask = np.random.randint(0, 2, num_institutions)
        inst_val_mask = np.random.randint(0, 2, num_institutions)

    # Edge features.
    cite_count = np.random.choice(10, num_cite_edges)
    write_year = np.random.choice(2022, num_write_edges)
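    # Year feature for the second 'writes' relation (institution -> paper).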
    write2_year = np.random.choice(2022, num_write_edges)

    # Save features. Each array is written to disk (chunk_graph consumes the
    # file paths) and also attached to `g`, which is returned to the caller.
    input_dir = os.path.join(root_dir, 'data_test')
    os.makedirs(input_dir)
    for sub_d in ['paper', 'cites', 'writes', 'writes2']:
        os.makedirs(os.path.join(input_dir, sub_d))

    paper_feat_path = os.path.join(input_dir, 'paper/feat.npy')
    with open(paper_feat_path, 'wb') as f:
        np.save(f, paper_feat)
    g.nodes['paper'].data['feat'] = torch.from_numpy(paper_feat)

    paper_label_path = os.path.join(input_dir, 'paper/label.npy')
    with open(paper_label_path, 'wb') as f:
        np.save(f, paper_label)
    g.nodes['paper'].data['label'] = torch.from_numpy(paper_label)

    paper_year_path = os.path.join(input_dir, 'paper/year.npy')
    with open(paper_year_path, 'wb') as f:
        np.save(f, paper_year)
    g.nodes['paper'].data['year'] = torch.from_numpy(paper_year)

    paper_orig_ids_path = os.path.join(input_dir, 'paper/orig_ids.npy')
    with open(paper_orig_ids_path, 'wb') as f:
        np.save(f, paper_orig_ids)
    g.nodes['paper'].data['orig_ids'] = torch.from_numpy(paper_orig_ids)

    cite_count_path = os.path.join(input_dir, 'cites/count.npy')
    with open(cite_count_path, 'wb') as f:
        np.save(f, cite_count)
    g.edges['cites'].data['count'] = torch.from_numpy(cite_count)

    write_year_path = os.path.join(input_dir, 'writes/year.npy')
    with open(write_year_path, 'wb') as f:
        np.save(f, write_year)
    g.edges[('author', 'writes', 'paper')].data['year'] = torch.from_numpy(
        write_year
    )
    g.edges['rev_writes'].data['year'] = torch.from_numpy(write_year)

    writes_orig_ids_path = os.path.join(input_dir, 'writes/orig_ids.npy')
    with open(writes_orig_ids_path, 'wb') as f:
        np.save(f, writes_orig_ids)
    g.edges[('author', 'writes', 'paper')].data[
        'orig_ids'
    ] = torch.from_numpy(writes_orig_ids)

    write2_year_path = os.path.join(input_dir, 'writes2/year.npy')
    with open(write2_year_path, 'wb') as f:
        np.save(f, write2_year)
    g.edges[('institution', 'writes', 'paper')].data[
        'year'
    ] = torch.from_numpy(write2_year)

    node_data = None
    if include_masks:
        for sub_d in ['author', 'institution']:
            os.makedirs(os.path.join(input_dir, sub_d))
        paper_train_mask_path = os.path.join(input_dir, 'paper/train_mask.npy')
        with open(paper_train_mask_path, 'wb') as f:
            np.save(f, paper_train_mask)

        paper_test_mask_path = os.path.join(input_dir, 'paper/test_mask.npy')
        with open(paper_test_mask_path, 'wb') as f:
            np.save(f, paper_test_mask)

        paper_val_mask_path = os.path.join(input_dir, 'paper/val_mask.npy')
        with open(paper_val_mask_path, 'wb') as f:
            np.save(f, paper_val_mask)

        author_train_mask_path = os.path.join(
            input_dir, 'author/train_mask.npy'
        )
        with open(author_train_mask_path, 'wb') as f:
            np.save(f, author_train_mask)

        author_test_mask_path = os.path.join(input_dir, 'author/test_mask.npy')
        with open(author_test_mask_path, 'wb') as f:
            np.save(f, author_test_mask)

        author_val_mask_path = os.path.join(input_dir, 'author/val_mask.npy')
        with open(author_val_mask_path, 'wb') as f:
            np.save(f, author_val_mask)

        inst_train_mask_path = os.path.join(
            input_dir, 'institution/train_mask.npy'
        )
        with open(inst_train_mask_path, 'wb') as f:
            np.save(f, inst_train_mask)

        inst_test_mask_path = os.path.join(
            input_dir, 'institution/test_mask.npy'
        )
        with open(inst_test_mask_path, 'wb') as f:
            np.save(f, inst_test_mask)

        inst_val_mask_path = os.path.join(input_dir, 'institution/val_mask.npy')
        with open(inst_val_mask_path, 'wb') as f:
            np.save(f, inst_val_mask)

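        # Map each node type to the on-disk files chunk_graph should read.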
        node_data = {
            'paper': {
                'feat': paper_feat_path,
                'train_mask': paper_train_mask_path,
                'test_mask': paper_test_mask_path,
                'val_mask': paper_val_mask_path,
                'label': paper_label_path,
                'year': paper_year_path,
                'orig_ids': paper_orig_ids_path,
            },
            'author': {
                'train_mask': author_train_mask_path,
                'test_mask': author_test_mask_path,
                'val_mask': author_val_mask_path,
            },
            'institution': {
                'train_mask': inst_train_mask_path,
                'test_mask': inst_test_mask_path,
                'val_mask': inst_val_mask_path,
            },
        }
    else:
        node_data = {
            'paper': {
                'feat': paper_feat_path,
                'label': paper_label_path,
                'year': paper_year_path,
                'orig_ids': paper_orig_ids_path,
            }
        }

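    # Edge features follow the same pattern; relations whose type name is
    # ambiguous ('writes') use the full canonical triple as the key.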
    edge_data = {
        'cites': {'count': cite_count_path},
        ('author', 'writes', 'paper'): {
            'year': write_year_path,
            'orig_ids': writes_orig_ids_path,
        },
        'rev_writes': {'year': write_year_path},
        ('institution', 'writes', 'paper'): {
            'year': write2_year_path,
        },
    }

    output_dir = os.path.join(root_dir, 'chunked-data')
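    # chunk_graph splits the graph structure and the features referenced in
    # node_data/edge_data into `num_chunks` chunks written under output_dir.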
    chunk_graph(
        g,
        'mag240m',
        node_data,
        edge_data,
        num_chunks=num_chunks,
        output_path=output_dir,
        data_fmt=data_fmt,
    )
    print('Done creating the chunked graph.')

    return g
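

# A minimal usage sketch (illustrative, not part of the original workflow):
# builds the chunked dataset in a temporary directory. The argument values
# below are assumptions chosen for demonstration.
if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as root_dir:
        g = create_chunked_dataset(
            root_dir, num_chunks=2, include_masks=True
        )
        print(g)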