Unverified Commit de34e15a authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

fix metapath2vec on custom datasets (#1499)

parent 16561a2e
......@@ -10,14 +10,26 @@ Dependencies
How to run the code
-----
Run with the following procedures:
1. Run sampler.py on your graph dataset. Note that the input text files should be lists of mappings, so you will probably need to preprocess your graph dataset. Files in the sample format are available in the "net_dbis" directory. Of course, you could also use your own metapath sampler implementation.
2. Run the following command:
```bash
python metapath2vec.py --download "where/you/want/to/download" --output_file "your_output_file_path"
```
Run with either of the following procedures:
* Running with the default AMiner dataset:
1. Directly run the following command:
```bash
python metapath2vec.py --aminer --path "where/you/want/to/download" --output_file "your_model_output_path"
```
* Running with another AMiner-like dataset:
1. Prepare the data in the same format as the AMiner and DBIS datasets described in Section B of the [original author's code repository](https://ericdongyx.github.io/metapath2vec/m2v.html).
2. Run `sampler.py` on your graph dataset, for instance:
```bash
python sampler.py net_dbis
```
3. Run the following command:
```bash
python metapath2vec.py --path net_dbis/output_path.txt --output_file "your_model_output_path"
```
Tips: Change `num_workers` based on your GPU instances; running 3 or 4 epochs is actually enough.
......
......@@ -44,3 +44,11 @@ class AminerDataset(object):
with zipfile.ZipFile(fn) as zf:
zf.extractall(path)
print('Unzip finished.')
class CustomDataset(object):
    """
    Custom dataset generated by sampler.py (e.g. NetDBIS).

    This class performs no download or extraction; it only records the
    location of an already-sampled metapath file and exposes it through the
    ``fn`` attribute, which is the attribute DataReader reads the input file
    name from.

    Parameters
    ----------
    path : str
        Path of the file produced by sampler.py
        (e.g. ``net_dbis/output_path.txt``).

    Raises
    ------
    ValueError
        If ``path`` is None or empty, e.g. because ``--path`` was omitted on
        the command line.
    """
    def __init__(self, path):
        # Reject a missing path here, with a message that names the cause,
        # rather than storing None and letting it surface later when the
        # file name is actually used.
        if not path:
            raise ValueError(
                "CustomDataset requires the path of the file generated by "
                "sampler.py (pass it with --path)")
        # Same attribute name AminerDataset exposes, so both dataset kinds
        # can be consumed interchangeably.
        self.fn = path
......@@ -7,11 +7,16 @@ from tqdm import tqdm
from reading_data import DataReader, Metapath2vecDataset
from model import SkipGramModel
from download import AminerDataset, CustomDataset
class Metapath2VecTrainer:
def __init__(self, args):
self.data = DataReader(args.download, args.min_count, args.care_type)
if args.aminer:
dataset = AminerDataset(args.path)
else:
dataset = CustomDataset(args.path)
self.data = DataReader(dataset, args.min_count, args.care_type)
dataset = Metapath2vecDataset(self.data, args.window_size)
self.dataloader = DataLoader(dataset, batch_size=args.batch_size,
shuffle=True, num_workers=args.num_workers, collate_fn=dataset.collate)
......@@ -60,7 +65,8 @@ class Metapath2VecTrainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Metapath2vec")
#parser.add_argument('--input_file', type=str, help="input_file")
parser.add_argument('--download', type=str, help="download_path")
parser.add_argument('--aminer', action='store_true', help='Use AMiner dataset')
parser.add_argument('--path', type=str, help="input_path")
parser.add_argument('--output_file', type=str, help='output_file')
parser.add_argument('--dim', default=128, type=int, help="embedding dimensions")
parser.add_argument('--window_size', default=7, type=int, help="context window size")
......
......@@ -7,7 +7,7 @@ np.random.seed(12345)
class DataReader:
NEGATIVE_TABLE_SIZE = 1e8
def __init__(self, download, min_count, care_type):
def __init__(self, dataset, min_count, care_type):
self.negatives = []
self.discards = []
......@@ -18,9 +18,7 @@ class DataReader:
self.sentences_count = 0
self.token_count = 0
self.word_frequency = dict()
self.download = download
FB = AminerDataset(self.download)
self.inputFileName = FB.fn
self.inputFileName = dataset.fn
self.read_words(min_count)
self.initTableNegatives()
self.initTableDiscards()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment