import argparse
import os
import random
import time

import dgl

import numpy as np
import torch
import torch.multiprocessing as mp

from model import SkipGramModel

from reading_data import DeepwalkDataset
from torch.utils.data import DataLoader

from utils import shuffle_walks, sum_up_params


class DeepwalkTrainer:
    def __init__(self, args):
        """Initialize the trainer: build the walk dataset and record sizes.

        Parameters
        ----------
        args : argparse.Namespace
            Parsed command-line arguments (see the parser at the bottom of
            this file).
        """
        self.args = args
        self.dataset = DeepwalkDataset(
            net_file=args.data_file,
            map_file=args.map_file,
            walk_length=args.walk_length,
            window_size=args.window_size,
            num_walks=args.num_walks,
            batch_size=args.batch_size,
            negative=args.negative,
            gpus=args.gpus,
            fast_neg=args.fast_neg,
            ogbl_name=args.ogbl_name,
            load_from_ogbl=args.load_from_ogbl,
        )
        # one embedding row per node in the graph
        self.emb_size = self.dataset.G.number_of_nodes()
        # built lazily in init_device_emb()
        self.emb_model = None

    def init_device_emb(self):
        """Build the SkipGram model and place it on the configured device(s).

        Will be called once in fast_train_mp / fast_train.
        """
        choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
        assert (
            choices == 1
        ), "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"

        # initializing embedding on CPU
        self.emb_model = SkipGramModel(
            emb_size=self.emb_size,
            emb_dimension=self.args.dim,
            walk_length=self.args.walk_length,
            window_size=self.args.window_size,
            batch_size=self.args.batch_size,
            only_cpu=self.args.only_cpu,
            only_gpu=self.args.only_gpu,
            mix=self.args.mix,
            neg_weight=self.args.neg_weight,
            negative=self.args.negative,
            lr=self.args.lr,
            lap_norm=self.args.lap_norm,
            fast_neg=self.args.fast_neg,
            record_loss=self.args.print_loss,
            norm=self.args.norm,
            use_context_weight=self.args.use_context_weight,
            async_update=self.args.async_update,
            num_threads=self.args.num_threads,
        )

        torch.set_num_threads(self.args.num_threads)
        if self.args.only_gpu:
            print("Run in 1 GPU")
            assert self.args.gpus[0] >= 0
            self.emb_model.all_to_device(self.args.gpus[0])
        elif self.args.mix:
            print("Mix CPU with %d GPU" % len(self.args.gpus))
            if len(self.args.gpus) == 1:
                # single GPU in mix mode: set the device here; with
                # multiple GPUs each subprocess sets its own device
                # in fast_train_sp
                assert (
                    self.args.gpus[0] >= 0
                ), "mix CPU with GPU should have available GPU"
                self.emb_model.set_device(self.args.gpus[0])
        else:
            print("Run in CPU process")
            # normalize to a one-element device list so train() takes
            # the single-process path
            self.args.gpus = [torch.device("cpu")]

    def train(self):
        """Train the embedding: multi-process iff more than one GPU given."""
        if len(self.args.gpus) > 1:
            self.fast_train_mp()
        else:
            self.fast_train()

    def _num_pos_per_walk(self):
        """Number of positive node pairs contributed by one walk of
        ``walk_length`` nodes with the configured ``window_size``."""
        return int(
            2 * self.args.walk_length * self.args.window_size
            - self.args.window_size * (self.args.window_size + 1)
        )

    def _make_dataloader(self, sampler):
        """Wrap ``sampler`` in a DataLoader over its seed node sequences;
        the sampler's ``sample`` produces a batch of walks from seed ids."""
        return DataLoader(
            dataset=sampler.seeds,
            batch_size=self.args.batch_size,
            collate_fn=sampler.sample,
            shuffle=False,
            drop_last=False,
            num_workers=self.args.num_sampler_threads,
        )

    def _sample_negatives(self, batch_size, num_pos):
        """Draw ``batch_size * num_pos * negative`` negative node ids from
        the pre-built negative table (used only when fast_neg is off)."""
        return torch.LongTensor(
            np.random.choice(
                self.dataset.neg_table,
                batch_size * num_pos * self.args.negative,
                replace=True,
            )
        )

    def _save_embedding(self):
        """Persist the trained embedding in the requested format
        (txt / pt / npy)."""
        if self.args.save_in_txt:
            self.emb_model.save_embedding_txt(
                self.dataset, self.args.output_emb_file
            )
        elif self.args.save_in_pt:
            self.emb_model.save_embedding_pt(
                self.dataset, self.args.output_emb_file
            )
        else:
            self.emb_model.save_embedding(
                self.dataset, self.args.output_emb_file
            )

    def fast_train_mp(self):
        """multi-cpu-core or mix cpu & multi-gpu"""
        self.init_device_emb()
        # shared across worker processes; workers update it in place
        self.emb_model.share_memory()

        if self.args.count_params:
            sum_up_params(self.emb_model)

        start_all = time.time()
        ps = []

        # one worker per configured GPU id
        for i in range(len(self.args.gpus)):
            p = mp.Process(
                target=self.fast_train_sp, args=(i, self.args.gpus[i])
            )
            ps.append(p)
            p.start()

        for p in ps:
            p.join()

        print("Used time: %.2fs" % (time.time() - start_all))
        self._save_embedding()

    def fast_train_sp(self, rank, gpu_id):
        """a subprocess for fast_train_mp"""
        if self.args.mix:
            self.emb_model.set_device(gpu_id)

        torch.set_num_threads(self.args.num_threads)
        if self.args.async_update:
            self.emb_model.create_async_update()

        # each rank gets its own disjoint slice of the walks
        sampler = self.dataset.create_sampler(rank)
        dataloader = self._make_dataloader(sampler)
        num_batches = len(dataloader)
        print(
            "num batchs: %d in process [%d] GPU [%d]"
            % (num_batches, rank, gpu_id)
        )
        # number of positive node pairs in a sequence
        num_pos = self._num_pos_per_walk()

        start = time.time()
        with torch.no_grad():
            for i, walks in enumerate(dataloader):
                if self.args.fast_neg:
                    # model draws (possibly false-)negatives internally
                    self.emb_model.fast_learn(walks)
                else:
                    # do negative sampling from the true-negative table
                    bs = len(walks)
                    neg_nodes = self._sample_negatives(bs, num_pos)
                    self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        print(
                            "GPU-[%d] batch %d time: %.2fs loss: %.4f"
                            % (
                                gpu_id,
                                i,
                                time.time() - start,
                                -sum(self.emb_model.loss)
                                / self.args.print_interval,
                            )
                        )
                        # reset the window so the next report is fresh
                        self.emb_model.loss = []
                    else:
                        print(
                            "GPU-[%d] batch %d time: %.2fs"
                            % (gpu_id, i, time.time() - start)
                        )
                    start = time.time()

            if self.args.async_update:
                self.emb_model.finish_async_update()

    def fast_train(self):
        """fast train with dataloader with only gpu / only cpu"""
        # the number of positive node pairs of a node sequence
        num_pos = self._num_pos_per_walk()

        self.init_device_emb()

        if self.args.async_update:
            self.emb_model.share_memory()
            self.emb_model.create_async_update()

        if self.args.count_params:
            sum_up_params(self.emb_model)

        sampler = self.dataset.create_sampler(0)
        dataloader = self._make_dataloader(sampler)

        num_batches = len(dataloader)
        print("num batchs: %d\n" % num_batches)

        start_all = time.time()
        start = time.time()
        with torch.no_grad():
            for i, walks in enumerate(dataloader):
                if self.args.fast_neg:
                    # model draws (possibly false-)negatives internally
                    self.emb_model.fast_learn(walks)
                else:
                    # do negative sampling from the true-negative table
                    bs = len(walks)
                    neg_nodes = self._sample_negatives(bs, num_pos)
                    self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        print(
                            "Batch %d training time: %.2fs loss: %.4f"
                            % (
                                i,
                                time.time() - start,
                                -sum(self.emb_model.loss)
                                / self.args.print_interval,
                            )
                        )
                        # reset the window so the next report is fresh
                        self.emb_model.loss = []
                    else:
                        print(
                            "Batch %d, training time: %.2fs"
                            % (i, time.time() - start)
                        )
                    start = time.time()

            if self.args.async_update:
                self.emb_model.finish_async_update()

        print("Training used time: %.2fs" % (time.time() - start_all))
        self._save_embedding()

if __name__ == "__main__":
    # Command-line entry point: parse arguments, derive fast_neg, and run
    # the trainer end to end.
    parser = argparse.ArgumentParser(description="DeepWalk")

    # input files
    ## personal datasets
    parser.add_argument(
        "--data_file",
        type=str,
        help="path of the txt network file, builtin dataset include youtube-net and blog-net",
    )
    ## ogbl datasets
    parser.add_argument(
        "--ogbl_name", type=str, help="name of ogbl dataset, e.g. ogbl-ddi"
    )
    parser.add_argument(
        "--load_from_ogbl",
        default=False,
        action="store_true",
        help="whether load dataset from ogbl",
    )

    # output files
    parser.add_argument(
        "--save_in_txt",
        default=False,
        action="store_true",
        help="Whether save data in txt format or npy",
    )
    parser.add_argument(
        "--save_in_pt",
        default=False,
        action="store_true",
        help="Whether save data in pt format or npy",
    )
    parser.add_argument(
        "--output_emb_file",
        type=str,
        default="emb.npy",
        help="path of the output npy embedding file",
    )
    parser.add_argument(
        "--map_file",
        type=str,
        default="nodeid_to_index.pickle",
        help="path of the mapping dict that maps node ids to embedding index",
    )
    parser.add_argument(
        "--norm",
        default=False,
        action="store_true",
        help="whether to do normalization over node embedding after training",
    )

    # model parameters
    parser.add_argument(
        "--dim", default=128, type=int, help="embedding dimensions"
    )
    parser.add_argument(
        "--window_size", default=5, type=int, help="context window size"
    )
    parser.add_argument(
        "--use_context_weight",
        default=False,
        action="store_true",
        help="whether to add weights over nodes in the context window",
    )
    parser.add_argument(
        "--num_walks",
        default=10,
        type=int,
        help="number of walks for each node",
    )
    parser.add_argument(
        "--negative",
        default=1,
        type=int,
        help="negative samples for each positive node pair",
    )
    parser.add_argument(
        "--batch_size",
        default=128,
        type=int,
        help="number of node sequences in each batch",
    )
    parser.add_argument(
        "--walk_length",
        default=80,
        type=int,
        help="number of nodes in a sequence",
    )
    parser.add_argument(
        "--neg_weight", default=1.0, type=float, help="negative weight"
    )
    parser.add_argument(
        "--lap_norm",
        default=0.01,
        type=float,
        help="weight of laplacian normalization, recommend to set as 0.1 / window_size",
    )

    # training parameters
    parser.add_argument(
        "--print_interval",
        default=100,
        type=int,
        help="number of batches between printing",
    )
    parser.add_argument(
        "--print_loss",
        default=False,
        action="store_true",
        help="whether print loss during training",
    )
    parser.add_argument("--lr", default=0.2, type=float, help="learning rate")

    # optimization settings
    parser.add_argument(
        "--mix",
        default=False,
        action="store_true",
        help="mixed training with CPU and GPU",
    )
    parser.add_argument(
        "--gpus",
        type=int,
        default=[-1],
        nargs="+",
        help="a list of active gpu ids, e.g. 0, used with --mix",
    )
    parser.add_argument(
        "--only_cpu",
        default=False,
        action="store_true",
        help="training with CPU",
    )
    parser.add_argument(
        "--only_gpu",
        default=False,
        action="store_true",
        help="training with GPU",
    )
    parser.add_argument(
        "--async_update",
        default=False,
        action="store_true",
        help="mixed training asynchronously, not recommended",
    )

    parser.add_argument(
        "--true_neg",
        default=False,
        action="store_true",
        help="If not specified, this program will use "
        "a faster negative sampling method, "
        "but the samples might be false negative "
        "with a small probability. If specified, "
        "this program will generate a true negative sample table, "
        "and select from it when doing negative sampling",
    )
    parser.add_argument(
        "--num_threads",
        default=8,
        type=int,
        help="number of threads used for each CPU-core/GPU",
    )
    parser.add_argument(
        "--num_sampler_threads",
        default=2,
        type=int,
        help="number of threads used for sampling",
    )

    parser.add_argument(
        "--count_params",
        default=False,
        action="store_true",
        help="count the params, exit once counting over",
    )

    args = parser.parse_args()
    # --true_neg turns OFF the (default) approximate negative sampling
    args.fast_neg = not args.true_neg

    if args.async_update:
        assert args.mix, "--async_update only with --mix"

    start_time = time.time()
    trainer = DeepwalkTrainer(args)
    trainer.train()
    print("Total used time: %.2f" % (time.time() - start_time))