import torch
import argparse
import dgl
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import os
import random
import time
import numpy as np

from reading_data import DeepwalkDataset
from model import SkipGramModel
from utils import thread_wrapped_func, shuffle_walks, sum_up_params

class DeepwalkTrainer:
    def __init__(self, args):
        """ Initializing the trainer with the input arguments """
        self.args = args
        self.dataset = DeepwalkDataset(
            net_file=args.data_file,
            map_file=args.map_file,
            walk_length=args.walk_length,
            window_size=args.window_size,
            num_walks=args.num_walks,
            batch_size=args.batch_size,
            negative=args.negative,
            gpus=args.gpus,
            fast_neg=args.fast_neg,
            ogbl_name=args.ogbl_name,
            load_from_ogbl=args.load_from_ogbl,
            )
        self.emb_size = self.dataset.G.number_of_nodes()
        self.emb_model = None

    def init_device_emb(self):
        """ set the device before training 
        will be called once in fast_train_mp / fast_train
        """
        choices = sum([self.args.only_gpu, self.args.only_cpu, self.args.mix])
        assert choices == 1, "Must choose only *one* training mode in [only_cpu, only_gpu, mix]"
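        # exactly one mode is active:
        #   only_cpu - embedding and updates stay on CPU
        #   only_gpu - the whole model is moved to a single GPU
        #   mix      - embedding is kept on CPU while GPUs help with the updates (see model.py)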
        
        # initializing embedding on CPU
        self.emb_model = SkipGramModel(
            emb_size=self.emb_size, 
            emb_dimension=self.args.dim,
            walk_length=self.args.walk_length,
            window_size=self.args.window_size,
            batch_size=self.args.batch_size,
            only_cpu=self.args.only_cpu,
            only_gpu=self.args.only_gpu,
            mix=self.args.mix,
            neg_weight=self.args.neg_weight,
            negative=self.args.negative,
            lr=self.args.lr,
            lap_norm=self.args.lap_norm,
            fast_neg=self.args.fast_neg,
            record_loss=self.args.print_loss,
            norm=self.args.norm,
            use_context_weight=self.args.use_context_weight,
            async_update=self.args.async_update,
            num_threads=self.args.num_threads,
            )
        
        torch.set_num_threads(self.args.num_threads)
        if self.args.only_gpu:
            print("Run in 1 GPU")
            assert self.args.gpus[0] >= 0
            self.emb_model.all_to_device(self.args.gpus[0])
        elif self.args.mix:
            print("Mix CPU with %d GPU" % len(self.args.gpus))
            if len(self.args.gpus) == 1:
                assert self.args.gpus[0] >= 0, 'mix CPU with GPU should have available GPU'
                self.emb_model.set_device(self.args.gpus[0])
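            # with more than one GPU, each worker sets its own device later in fast_train_sp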
        else:
            print("Run in CPU process")
            self.args.gpus = [torch.device('cpu')]


    def train(self):
        """ train the embedding """
        if len(self.args.gpus) > 1:
            self.fast_train_mp()
        else:
            self.fast_train()

    def fast_train_mp(self):
        """ multi-cpu-core or mix cpu & multi-gpu """
        self.init_device_emb()
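        # keep the embedding in shared memory so every worker process
        # updates the same tensors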
        self.emb_model.share_memory()

        if self.args.count_params:
            sum_up_params(self.emb_model)

        start_all = time.time()
        ps = []

        for i in range(len(self.args.gpus)):
            p = mp.Process(target=self.fast_train_sp, args=(i, self.args.gpus[i]))
            ps.append(p)
            p.start()

        for p in ps:
            p.join()
        
        print("Used time: %.2fs" % (time.time()-start_all))
        if self.args.save_in_txt:
            self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file)
        elif self.args.save_in_pt:
            self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)

    @thread_wrapped_func
    def fast_train_sp(self, rank, gpu_id):
        """ a subprocess for fast_train_mp """
        if self.args.mix:
            self.emb_model.set_device(gpu_id)
        
        torch.set_num_threads(self.args.num_threads)
        if self.args.async_update:
            self.emb_model.create_async_update()

        sampler = self.dataset.create_sampler(rank)
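        # sampler.seeds holds this worker's walk seeds; sampler.sample (used as
        # collate_fn below) expands a batch of seeds into random walks
        # (see reading_data.py)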

        dataloader = DataLoader(
            dataset=sampler.seeds,
            batch_size=self.args.batch_size,
            collate_fn=sampler.sample,
            shuffle=False,
            drop_last=False,
            num_workers=self.args.num_sampler_threads,
            )
        num_batches = len(dataloader)
        print("num batchs: %d in process [%d] GPU [%d]" % (num_batches, rank, gpu_id))
        # number of positive node pairs in a sequence: each of the walk_length
        # nodes is paired with up to window_size context nodes on each side,
        # minus the pairs lost at the two ends of the walk
        num_pos = int(2 * self.args.walk_length * self.args.window_size\
            - self.args.window_size * (self.args.window_size + 1))
        
        start = time.time()
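        # fast_learn is expected to apply parameter updates itself (see
        # model.py), so autograd is disabled for the whole loop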
        with torch.no_grad():
            for i, walks in enumerate(dataloader):
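                # with --fast_neg, negatives are taken from inside the batch;
                # otherwise they are drawn from the dataset's negative table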
                if self.args.fast_neg:
                    self.emb_model.fast_learn(walks)
                else:
                    # do negative sampling
                    bs = len(walks)
                    neg_nodes = torch.LongTensor(
                        np.random.choice(self.dataset.neg_table, 
                            bs * num_pos * self.args.negative, 
                            replace=True))
                    self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        print("GPU-[%d] batch %d time: %.2fs loss: %.4f" \
                            % (gpu_id, i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
                        self.emb_model.loss = []
                    else:
                        print("GPU-[%d] batch %d time: %.2fs" % (gpu_id, i, time.time()-start))
                    start = time.time()

            if self.args.async_update:
                self.emb_model.finish_async_update()

    def fast_train(self):
        """ Fast training with a dataloader, using only GPU or only CPU """
        # the number of positive node pairs of a node sequence
        # (same counting as in fast_train_sp)
        num_pos = 2 * self.args.walk_length * self.args.window_size\
            - self.args.window_size * (self.args.window_size + 1)
        num_pos = int(num_pos)

        self.init_device_emb()

        if self.args.async_update:
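            # share_memory lets the asynchronous updater created below write
            # into the same embedding tensors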
            self.emb_model.share_memory()
            self.emb_model.create_async_update()

        if self.args.count_params:
            sum_up_params(self.emb_model)

        sampler = self.dataset.create_sampler(0)
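        # single-process training: one sampler (rank 0) produces all walks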

        dataloader = DataLoader(
            dataset=sampler.seeds,
            batch_size=self.args.batch_size,
            collate_fn=sampler.sample,
            shuffle=False,
            drop_last=False,
            num_workers=self.args.num_sampler_threads,
            )
        
        num_batches = len(dataloader)
        print("num batchs: %d\n" % num_batches)

        start_all = time.time()
        start = time.time()
        with torch.no_grad():
            max_i = num_batches
            for i, walks in enumerate(dataloader):
                if self.args.fast_neg:
                    self.emb_model.fast_learn(walks)
                else:
                    # do negative sampling
                    bs = len(walks)
                    neg_nodes = torch.LongTensor(
                        np.random.choice(self.dataset.neg_table, 
                            bs * num_pos * self.args.negative, 
                            replace=True))
                    self.emb_model.fast_learn(walks, neg_nodes=neg_nodes)

                if i > 0 and i % self.args.print_interval == 0:
                    if self.args.print_loss:
                        print("Batch %d training time: %.2fs loss: %.4f" \
                            % (i, time.time()-start, -sum(self.emb_model.loss)/self.args.print_interval))
                        self.emb_model.loss = []
                    else:
                        print("Batch %d, training time: %.2fs" % (i, time.time()-start))
                    start = time.time()

            if self.args.async_update:
                self.emb_model.finish_async_update()

        print("Training used time: %.2fs" % (time.time()-start_all))
        if self.args.save_in_txt:
            self.emb_model.save_embedding_txt(self.dataset, self.args.output_emb_file)
        elif self.args.save_in_pt:
            self.emb_model.save_embedding_pt(self.dataset, self.args.output_emb_file)
        else:
            self.emb_model.save_embedding(self.dataset, self.args.output_emb_file)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="DeepWalk")
    # input files
    ## personal datasets
    parser.add_argument('--data_file', type=str,
            help="path of the txt network file; builtin datasets include youtube-net and blog-net")
    ## ogbl datasets
    parser.add_argument('--ogbl_name', type=str, 
            help="name of ogbl dataset, e.g. ogbl-ddi")
    parser.add_argument('--load_from_ogbl', default=False, action="store_true",
            help="whether load dataset from ogbl")

    # output files
    parser.add_argument('--save_in_txt', default=False, action="store_true",
            help='whether to save the embedding in txt format instead of npy')
    parser.add_argument('--save_in_pt', default=False, action="store_true",
            help='whether to save the embedding in pt format instead of npy')
    parser.add_argument('--output_emb_file', type=str, default="emb.npy",
            help='path of the output npy embedding file')
    parser.add_argument('--map_file', type=str, default="nodeid_to_index.pickle",
            help='path of the mapping dict that maps node ids to embedding index')
    parser.add_argument('--norm', default=False, action="store_true", 
            help="whether to do normalization over node embedding after training")
    
    # model parameters
    parser.add_argument('--dim', default=128, type=int, 
            help="embedding dimensions")
    parser.add_argument('--window_size', default=5, type=int, 
            help="context window size")
    parser.add_argument('--use_context_weight', default=False, action="store_true", 
            help="whether to add weights over nodes in the context window")
    parser.add_argument('--num_walks', default=10, type=int, 
            help="number of walks for each node")
    parser.add_argument('--negative', default=1, type=int, 
            help="negative samples for each positive node pair")
    parser.add_argument('--batch_size', default=128, type=int, 
            help="number of node sequences in each batch")
    parser.add_argument('--walk_length', default=80, type=int, 
            help="number of nodes in a sequence")
    parser.add_argument('--neg_weight', default=1., type=float, 
            help="negative weight")
    parser.add_argument('--lap_norm', default=0.01, type=float, 
            help="weight of Laplacian normalization; recommended value: 0.1 / window_size")
    
    # training parameters
    parser.add_argument('--print_interval', default=100, type=int, 
            help="number of batches between printing")
    parser.add_argument('--print_loss', default=False, action="store_true", 
            help="whether print loss during training")
    parser.add_argument('--lr', default=0.2, type=float, 
            help="learning rate")
    
    # optimization settings
    parser.add_argument('--mix', default=False, action="store_true", 
            help="mixed training with CPU and GPU")
    parser.add_argument('--gpus', type=int, default=[-1], nargs='+', 
            help='a list of active gpu ids, e.g. 0, used with --mix')
    parser.add_argument('--only_cpu', default=False, action="store_true", 
            help="training with CPU")
    parser.add_argument('--only_gpu', default=False, action="store_true", 
            help="training with GPU")
    parser.add_argument('--async_update', default=False, action="store_true", 
            help="mixed training asynchronously, not recommended")

    parser.add_argument('--fast_neg', default=False, action="store_true", 
            help="do negative sampling inside a batch")
    parser.add_argument('--num_threads', default=8, type=int, 
            help="number of threads used for each CPU-core/GPU")
    parser.add_argument('--num_sampler_threads', default=2, type=int, 
            help="number of threads used for sampling")
    
    parser.add_argument('--count_params', default=False, action="store_true", 
            help="count the params, exit once counting over")

    args = parser.parse_args()

    if args.async_update:
        assert args.mix, "--async_update only with --mix"

    start_time = time.time()
    trainer = DeepwalkTrainer(args)
    trainer.train()
    print("Total used time: %.2f" % (time.time() - start_time))