"tests/dist/vscode:/vscode.git/clone" did not exist on "2caa6bd02d9d86c911021dcf86781645e27273d9"
.. _apimodelpretrain:

Pre-trained Models
==================

We provide several pre-trained models that can be used directly, without training from scratch.

Example Usage
-------------

Property Prediction
```````````````````

.. code-block:: python

    from dgllife.data import Tox21
    from dgllife.model import load_pretrained
    from dgllife.utils import smiles_to_bigraph, CanonicalAtomFeaturizer

    dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
    model = load_pretrained('GCN_Tox21') # Pretrained model loaded
    model.eval()

    smiles, g, label, mask = dataset[0]
    feats = g.ndata.pop('h')
    label_pred = model(g, feats)
    print(smiles)                   # CCOc1ccc2nc(S(N)(=O)=O)sc2c1
    print(label_pred[:, mask != 0]) # Keep predictions only for tasks with existing labels
    # tensor([[ 1.4190, -0.1820,  1.2974,  1.4416,  0.6914,
    # 2.0957,  0.5919,  0.7715, 1.7273,  0.2070]])
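
The scores printed above fall outside the [0, 1] range, which suggests they are raw logits rather than probabilities. Assuming that is the case, as is usual for binary classification tasks such as those in Tox21, a sigmoid maps each score to a per-task probability. The following is a minimal, dependency-free sketch that reuses the values printed above; the 0.5 threshold is an illustrative assumption, not something prescribed by the library.

.. code-block:: python

    import math

    # Scores copied from the example output above (assumed to be logits)
    logits = [1.4190, -0.1820, 1.2974, 1.4416, 0.6914,
              2.0957, 0.5919, 0.7715, 1.7273, 0.2070]

    # Sigmoid maps each logit to a probability in (0, 1)
    probs = [1.0 / (1.0 + math.exp(-x)) for x in logits]

    # Illustrative decision rule: positive when the probability exceeds 0.5
    positive = [p > 0.5 for p in probs]

    print([round(p, 3) for p in probs])
    print(sum(positive), 'of', len(positive), 'tasks predicted positive')

With tensors, ``torch.sigmoid(label_pred)`` performs the same transformation.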

Generative Models
`````````````````

.. code-block:: python

    from dgllife.model import load_pretrained

    model = load_pretrained('DGMG_ZINC_canonical')
    model.eval()
    smiles = []
    for i in range(4):
        smiles.append(model(rdkit_mol=True))

    print(smiles)
    # ['CC1CCC2C(CCC3C2C(NC2=CC(Cl)=CC=C2N)S3(=O)=O)O1',
    # 'O=C1SC2N=CN=C(NC(SC3=CC=CC=N3)C1=CC=CO)C=2C1=CCCC1',
    # 'CC1C=CC(=CC=1)C(=O)NN=C(C)C1=CC=CC2=CC=CC=C21',
    # 'CCN(CC1=CC=CC=C1F)CC1CCCN(C)C1']

If you are running the code block above in a Jupyter notebook, you can also visualize the generated molecules with:

.. code-block:: python

    from IPython.display import SVG
    from rdkit import Chem
    from rdkit.Chem import Draw

    mols = [Chem.MolFromSmiles(s) for s in smiles]
    SVG(Draw.MolsToGridImage(mols, molsPerRow=4, subImgSize=(180, 150), useSVG=True))

.. image:: https://data.dgl.ai/dgllife/dgmg/dgmg_model_zoo_example2.png
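
A generative model can in principle return the same molecule more than once, or a string that RDKit cannot parse. The sketch below shows one possible way to keep only distinct, parseable molecules before further analysis. The helper ``unique_valid_smiles`` is hypothetical (it is not part of ``dgllife``), and the input strings are illustrative rather than model output.

.. code-block:: python

    from rdkit import Chem

    def unique_valid_smiles(candidates):
        """Return canonical SMILES for the distinct, parseable molecules.

        Hypothetical helper for illustration; not part of dgllife.
        """
        seen = set()
        kept = []
        for s in candidates:
            mol = Chem.MolFromSmiles(s)  # None when RDKit cannot parse the string
            if mol is None:
                continue
            canonical = Chem.MolToSmiles(mol)  # Canonical form makes duplicates comparable
            if canonical not in seen:
                seen.add(canonical)
                kept.append(canonical)
        return kept

    # Illustrative inputs: 'CCO' and 'OCC' are the same molecule, 'C1CC' has an unclosed ring
    candidates = ['CCO', 'OCC', 'c1ccccc1', 'C1CC']
    filtered = unique_valid_smiles(candidates)
    print(filtered)

Comparing canonical SMILES instead of the raw strings means that two different spellings of the same molecule are counted once.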

API
---

.. autofunction:: dgllife.model.load_pretrained