"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "5f04fc2b31410a8801d729207fd8f28ce233bd92"
Unverified commit 0d878ff8, authored by Quan (Andy) Gan and committed by GitHub.
Browse files

[Example] Cleaned GraphSAGE node classification example with PyTorch Lightning (#3863)

* cleaned pl node classification example

* conform to PL's method of updating the dataloader

* update

* lint

* fix test

* fix
parent baa92928
...@@ -37,6 +37,15 @@ python3 node_classification.py ...@@ -37,6 +37,15 @@ python3 node_classification.py
python3 multi_gpu_node_classification.py python3 multi_gpu_node_classification.py
``` ```
### PyTorch Lightning for node classification
Train w/ mini-batch sampling for node classification with PyTorch Lightning on OGB-products.
Works with both single GPU and multiple GPUs:
```bash
python3 lightning/node_classification.py
```
### Minibatch training for link prediction ### Minibatch training for link prediction
Train w/ mini-batch sampling for link prediction on OGB-Citation2: Train w/ mini-batch sampling for link prediction on OGB-Citation2:
......
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
import glob
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from load_graph import load_reddit, inductive_split, load_ogb
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from model import SAGE
class SAGELightning(LightningModule):
    """GraphSAGE node-classification model wrapped as a LightningModule.

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    n_hidden : int
        Hidden layer width.
    n_classes : int
        Number of output classes.
    n_layers : int
        Number of GraphSAGE layers.
    activation : callable
        Activation function between layers.
    dropout : float
        Dropout probability.
    lr : float
        Learning rate for the Adam optimizer.
    """
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 lr):
        super().__init__()
        # Persist constructor arguments so load_from_checkpoint can rebuild
        # the model without re-specifying them.
        self.save_hyperparameters()
        self.module = SAGE(in_feats, n_hidden, n_classes, n_layers, activation, dropout)
        self.lr = lr
        # The usage of `train_acc` and `val_acc` is the recommended practice from now on as per
        # https://torchmetrics.readthedocs.io/en/latest/pages/lightning.html
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        input_nodes, output_nodes, mfgs = batch
        # FIX: use the device Lightning assigned to this module (self.device)
        # instead of the module-level global ``device``, which only exists
        # when this file is executed as a script and breaks on import.
        mfgs = [mfg.int().to(self.device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata['features']
        batch_labels = mfgs[-1].dstdata['labels']
        batch_pred = self.module(mfgs, batch_inputs)
        loss = F.cross_entropy(batch_pred, batch_labels)
        self.train_acc(th.softmax(batch_pred, 1), batch_labels)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, mfgs = batch
        # Same device fix as in training_step.
        mfgs = [mfg.int().to(self.device) for mfg in mfgs]
        batch_inputs = mfgs[0].srcdata['features']
        batch_labels = mfgs[-1].dstdata['labels']
        batch_pred = self.module(mfgs, batch_inputs)
        self.val_acc(th.softmax(batch_pred, 1), batch_labels)
        # sync_dist aggregates the metric across processes under DDP.
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """Return the Adam optimizer with the configured learning rate."""
        optimizer = th.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
class DataModule(LightningDataModule):
    """Lightning data module producing DGL neighbor-sampling dataloaders.

    Parameters
    ----------
    dataset_name : str
        Either ``'reddit'`` or ``'ogbn-products'``.
    data_cpu : bool
        If False, move the graph and node ID tensors to ``device`` up front.
    fan_out : sequence of int
        Number of neighbors to sample per layer.
    device : torch.device
        Target device for graph/IDs when ``data_cpu`` is False.
    batch_size : int
        Mini-batch size (number of seed nodes).
    num_workers : int
        Number of sampling worker processes.
    """
    # FIX: the default was the mutable list [10, 25]; use a tuple so the
    # default can never be mutated across instances.
    def __init__(self, dataset_name, data_cpu=False, fan_out=(10, 25),
                 device=th.device('cpu'), batch_size=1000, num_workers=4):
        super().__init__()
        if dataset_name == 'reddit':
            g, n_classes = load_reddit()
        elif dataset_name == 'ogbn-products':
            g, n_classes = load_ogb('ogbn-products')
        else:
            raise ValueError('unknown dataset')

        train_nid = th.nonzero(g.ndata['train_mask'], as_tuple=True)[0]
        val_nid = th.nonzero(g.ndata['val_mask'], as_tuple=True)[0]
        # Test nodes are those in neither the training nor the validation set.
        test_nid = th.nonzero(
            ~(g.ndata['train_mask'] | g.ndata['val_mask']), as_tuple=True)[0]

        sampler = dgl.dataloading.MultiLayerNeighborSampler(
            [int(fanout) for fanout in fan_out])

        dataloader_device = th.device('cpu')
        if not data_cpu:
            # Keep IDs and graph on the target device to avoid per-batch copies.
            train_nid = train_nid.to(device)
            val_nid = val_nid.to(device)
            test_nid = test_nid.to(device)
            # Materialize only the CSC format before the device copy.
            g = g.formats(['csc'])
            g = g.to(device)
            dataloader_device = device

        self.g = g
        self.train_nid, self.val_nid, self.test_nid = train_nid, val_nid, test_nid
        self.sampler = sampler
        self.device = dataloader_device
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.in_feats = g.ndata['features'].shape[1]
        self.n_classes = n_classes

    def _dataloader(self, nids):
        """Build a NodeDataLoader over ``nids`` with the shared settings."""
        return dgl.dataloading.NodeDataLoader(
            self.g,
            nids,
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=self.num_workers)

    def train_dataloader(self):
        return self._dataloader(self.train_nid)

    def val_dataloader(self):
        # NOTE(review): shuffle=True is kept from the original; validation
        # order does not affect the epoch-level metric.
        return self._dataloader(self.val_nid)
def evaluate(model, g, val_nid, device, batch_size=None, num_workers=None):
    """
    Evaluate the model on the validation set specified by ``val_nid``.

    Parameters
    ----------
    model : SAGELightning
        Trained wrapper; ``model.module.inference`` performs layer-wise
        full-neighbor inference.
    g : DGLGraph
        The entire graph.
    val_nid : Tensor
        The node IDs for validation.
    device : torch.device
        The GPU device to evaluate on.
    batch_size : int, optional
        Inference batch size; defaults to the command-line ``--batch-size``.
    num_workers : int, optional
        Number of sampling workers; defaults to ``--num-workers``.

    Returns
    -------
    Tensor
        Scalar validation accuracy.
    """
    # FIX: the original silently read the module-level ``args`` globals.
    # Accept explicit keyword arguments and fall back to the globals only
    # when they are omitted, keeping the script's behavior unchanged.
    if batch_size is None:
        batch_size = args.batch_size
    if num_workers is None:
        num_workers = args.num_workers
    model.eval()
    nfeat = g.ndata['features']
    labels = g.ndata['labels']
    with th.no_grad():
        pred = model.module.inference(g, nfeat, device, batch_size, num_workers)
    model.train()
    test_acc = Accuracy()
    # Move labels to the prediction's device before computing accuracy.
    return test_acc(th.softmax(pred[val_nid], -1), labels[val_nid].to(pred.device))
if __name__ == '__main__':
    # Command-line configuration for the Lightning node-classification example.
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--gpu', type=int, default=0,
                           help="GPU device ID. Use -1 for CPU training")
    argparser.add_argument('--dataset', type=str, default='reddit')
    argparser.add_argument('--num-epochs', type=int, default=20)
    argparser.add_argument('--num-hidden', type=int, default=16)
    argparser.add_argument('--num-layers', type=int, default=2)
    argparser.add_argument('--fan-out', type=str, default='10,25')
    argparser.add_argument('--batch-size', type=int, default=1000)
    argparser.add_argument('--log-every', type=int, default=20)
    argparser.add_argument('--eval-every', type=int, default=5)
    argparser.add_argument('--lr', type=float, default=0.003)
    argparser.add_argument('--dropout', type=float, default=0.5)
    argparser.add_argument('--num-workers', type=int, default=0,
                           help="Number of sampling processes. Use 0 for no extra process.")
    argparser.add_argument('--inductive', action='store_true',
                           help="Inductive learning setting")
    argparser.add_argument('--data-cpu', action='store_true',
                           help="By default the script puts the graph, node features and labels "
                                "on GPU when using it to save time for data copy. This may "
                                "be undesired if they cannot fit in GPU memory at once. "
                                "This flag disables that.")
    args = argparser.parse_args()

    if args.gpu >= 0:
        device = th.device('cuda:%d' % args.gpu)
    else:
        device = th.device('cpu')

    # '--fan-out' is a comma-separated string; split it into per-layer fanouts.
    datamodule = DataModule(
        args.dataset, args.data_cpu, [int(_) for _ in args.fan_out.split(',')],
        device, args.batch_size, args.num_workers)
    model = SAGELightning(
        datamodule.in_feats, args.num_hidden, datamodule.n_classes, args.num_layers,
        F.relu, args.dropout, args.lr)

    # Train
    # Keep only the checkpoint with the best validation accuracy.
    checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=1)
    trainer = Trainer(gpus=[args.gpu] if args.gpu != -1 else None,
                      max_epochs=args.num_epochs,
                      callbacks=[checkpoint_callback])
    trainer.fit(model, datamodule=datamodule)

    # Test
    # Locate the newest lightning_logs version directory and reload its best
    # checkpoint for the final evaluation on the held-out test nodes.
    dirs = glob.glob('./lightning_logs/*')
    version = max([int(os.path.split(x)[-1].split('_')[-1]) for x in dirs])
    logdir = './lightning_logs/version_%d' % version
    print('Evaluating model in', logdir)
    ckpt = glob.glob(os.path.join(logdir, 'checkpoints', '*'))[0]
    model = SAGELightning.load_from_checkpoint(
        checkpoint_path=ckpt, hparams_file=os.path.join(logdir, 'hparams.yaml')).to(device)
    test_acc = evaluate(model, datamodule.g, datamodule.test_nid, device)
    print('Test accuracy:', test_acc)
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import tqdm
import glob
import os
from ogb.nodeproppred import DglNodePropPredDataset
from torchmetrics import Accuracy
import torchmetrics.functional as MF
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
class SAGE(LightningModule):
    """Three-layer GraphSAGE model for node classification as a LightningModule.

    Parameters
    ----------
    in_feats : int
        Number of input node features.
    n_hidden : int
        Hidden layer width.
    n_classes : int
        Number of output classes.
    """
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        # Persist constructor arguments so load_from_checkpoint can rebuild
        # the model without re-specifying them.
        self.save_hyperparameters()
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        # Logging the metric objects themselves lets Lightning handle
        # accumulation and per-epoch reset.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

    def forward(self, blocks, x):
        """Run the layers over a list of message flow graphs (blocks)."""
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # No activation/dropout after the last layer (raw class scores).
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the official
        # example is that the intermediate results can also benefit from prefetching.
        #
        # Layer-wise full-neighbor inference: compute layer l's output for
        # every node before moving on to layer l+1.
        g.ndata['h'] = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
        dataloader = dgl.dataloading.NodeDataLoader(
            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
            batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
            persistent_workers=(num_workers > 0))
        if buffer_device is None:
            buffer_device = device
        for l, layer in enumerate(self.layers):
            # Output buffer for this layer, kept on ``buffer_device`` (e.g. CPU)
            # so the full-graph activations need not fit on ``device``.
            y = torch.zeros(
                g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                x = blocks[0].srcdata['h']
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                y[output_nodes] = h.to(buffer_device)
            # This layer's output becomes the next layer's input feature,
            # picked up by the prefetcher via the 'h' key.
            g.ndata['h'] = y
        return y

    def training_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata['feat']
        y = blocks[-1].dstdata['label']
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata['feat']
        y = blocks[-1].dstdata['label']
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        # sync_dist aggregates the metric across DDP ranks.
        self.log('val_acc', self.val_acc, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True)

    def configure_optimizers(self):
        """Return an Adam optimizer with fixed lr and weight decay."""
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=5e-4)
        return optimizer
class DataModule(LightningDataModule):
    """Serves GPU neighbor-sampled train/validation dataloaders for a DGL graph."""

    def __init__(self, graph, train_idx, val_idx, fanouts, batch_size, n_classes):
        super().__init__()
        self.g = graph
        self.train_idx, self.val_idx = train_idx, val_idx
        # Prefetch features and labels alongside the sampled blocks.
        self.sampler = dgl.dataloading.NeighborSampler(
            fanouts, prefetch_node_feats=['feat'], prefetch_labels=['label'])
        self.batch_size = batch_size
        self.in_feats = graph.ndata['feat'].shape[1]
        self.n_classes = n_classes

    def train_dataloader(self):
        # Sampling runs on GPU via UVA; use_ddp shards seed nodes across ranks.
        # For CPU sampling, set num_workers to nonzero and use_uva=False.
        # Set use_ddp to False for single GPU.
        return dgl.dataloading.DataLoader(
            self.g, self.train_idx.to('cuda'), self.sampler,
            device='cuda', batch_size=self.batch_size, shuffle=True,
            drop_last=False, num_workers=0, use_uva=True, use_ddp=True)

    def val_dataloader(self):
        return dgl.dataloading.DataLoader(
            self.g, self.val_idx.to('cuda'), self.sampler,
            device='cuda', batch_size=self.batch_size, shuffle=True,
            drop_last=False, num_workers=0, use_uva=True)
if __name__ == '__main__':
    dataset = DglNodePropPredDataset('ogbn-products')
    graph, labels = dataset[0]
    # Store labels on the graph so the sampler can prefetch them per batch.
    graph.ndata['label'] = labels.squeeze()
    split_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
    datamodule = DataModule(graph, train_idx, val_idx, [15, 10, 5], 1024, dataset.num_classes)
    model = SAGE(datamodule.in_feats, 256, datamodule.n_classes)

    # Train
    # Keep only the checkpoint with the best validation accuracy.
    checkpoint_callback = ModelCheckpoint(monitor='val_acc', save_top_k=1)
    # Use this for single GPU
    #trainer = Trainer(gpus=[0], max_epochs=10, callbacks=[checkpoint_callback])
    trainer = Trainer(gpus=[0, 1, 2, 3], max_epochs=10, callbacks=[checkpoint_callback], strategy='ddp_spawn')
    trainer.fit(model, datamodule=datamodule)

    # Test
    # Locate the newest lightning_logs version directory and reload its best
    # checkpoint for final evaluation on the held-out test nodes.
    dirs = glob.glob('./lightning_logs/*')
    version = max([int(os.path.split(x)[-1].split('_')[-1]) for x in dirs])
    logdir = './lightning_logs/version_%d' % version
    print('Evaluating model in', logdir)
    ckpt = glob.glob(os.path.join(logdir, 'checkpoints', '*'))[0]
    model = SAGE.load_from_checkpoint(
        checkpoint_path=ckpt, hparams_file=os.path.join(logdir, 'hparams.yaml')).to('cuda')
    with torch.no_grad():
        # Layer-wise inference on GPU; outputs are buffered on the graph's
        # (CPU) device to bound GPU memory usage.
        pred = model.inference(graph, 'cuda', 4096, 12, graph.device)
        pred = pred[test_idx]
        label = graph.ndata['label'][test_idx]
        acc = MF.accuracy(pred, label)
    print('Test accuracy:', acc)
...@@ -692,6 +692,36 @@ class DataLoader(torch.utils.data.DataLoader): ...@@ -692,6 +692,36 @@ class DataLoader(torch.utils.data.DataLoader):
ddp_seed=0, batch_size=1, drop_last=False, shuffle=False, ddp_seed=0, batch_size=1, drop_last=False, shuffle=False,
use_prefetch_thread=None, use_alternate_streams=None, use_prefetch_thread=None, use_alternate_streams=None,
pin_prefetcher=None, use_uva=False, **kwargs): pin_prefetcher=None, use_uva=False, **kwargs):
# (BarclayII) PyTorch Lightning sometimes will recreate a DataLoader from an existing
# DataLoader with modifications to the original arguments. The arguments are retrieved
# from the attributes with the same name, and because we change certain arguments
# when calling super().__init__() (e.g. batch_size attribute is None even if the
# batch_size argument is not, so the next DataLoader's batch_size argument will be
# None), we cannot reinitialize the DataLoader with attributes from the previous
# DataLoader directly.
# A workaround is to check whether "collate_fn" appears in kwargs. If "collate_fn"
# is indeed in kwargs and it's already a CollateWrapper object, we can assume that
# the arguments come from a previously created DGL DataLoader, and directly initialize
# the new DataLoader from kwargs without any changes.
if isinstance(kwargs.get('collate_fn', None), CollateWrapper):
assert batch_size is None # must be None
# restore attributes
self.graph = graph
self.indices = indices
self.graph_sampler = graph_sampler
self.device = device
self.use_ddp = use_ddp
self.ddp_seed = ddp_seed
self.shuffle = shuffle
self.drop_last = drop_last
self.use_prefetch_thread = use_prefetch_thread
self.use_alternate_streams = use_alternate_streams
self.pin_prefetcher = pin_prefetcher
self.use_uva = use_uva
kwargs['batch_size'] = None
super().__init__(**kwargs)
return
if isinstance(graph, DistGraph): if isinstance(graph, DistGraph):
raise TypeError( raise TypeError(
'Please use dgl.dataloading.DistNodeDataLoader or ' 'Please use dgl.dataloading.DistNodeDataLoader or '
...@@ -808,6 +838,7 @@ class DataLoader(torch.utils.data.DataLoader): ...@@ -808,6 +838,7 @@ class DataLoader(torch.utils.data.DataLoader):
self.use_alternate_streams = use_alternate_streams self.use_alternate_streams = use_alternate_streams
self.pin_prefetcher = pin_prefetcher self.pin_prefetcher = pin_prefetcher
self.use_prefetch_thread = use_prefetch_thread self.use_prefetch_thread = use_prefetch_thread
worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None)) worker_init_fn = WorkerInitWrapper(kwargs.get('worker_init_fn', None))
self.other_storages = {} self.other_storages = {}
......
...@@ -124,6 +124,48 @@ def cast_to_signed(arr): ...@@ -124,6 +124,48 @@ def cast_to_signed(arr):
""" """
return _CAPI_DGLArrayCastToSigned(arr) return _CAPI_DGLArrayCastToSigned(arr)
def get_shared_mem_array(name, shape, dtype):
    """Attach to an existing shared-memory tensor with the given name.

    Parameters
    ----------
    name : str
        The unique name of the shared memory
    shape : tuple of int
        The shape of the returned tensor
    dtype : F.dtype
        The dtype of the returned tensor

    Returns
    -------
    F.tensor
        The tensor got from shared memory.
    """
    # ``False`` means "do not create": the named segment must already exist.
    shm = empty_shared_mem(name, False, shape, F.reverse_data_type_dict[dtype])
    return F.zerocopy_from_dlpack(shm.to_dlpack())
def create_shared_mem_array(name, shape, dtype):
    """Allocate a new shared-memory segment and expose it as a tensor.

    Parameters
    ----------
    name : str
        The unique name of the shared memory
    shape : tuple of int
        The shape of the returned tensor
    dtype : F.dtype
        The dtype of the returned tensor

    Returns
    -------
    F.tensor
        The created tensor.
    """
    # ``True`` requests creation of a fresh segment under ``name``.
    shm = empty_shared_mem(name, True, shape, F.reverse_data_type_dict[dtype])
    return F.zerocopy_from_dlpack(shm.to_dlpack())
def exist_shared_mem_array(name): def exist_shared_mem_array(name):
""" Check the existence of shared-memory array. """ Check the existence of shared-memory array.
......
"""Shared memory utilities.""" """Shared memory utilities.
from .. import backend as F
from .._ffi.ndarray import empty_shared_mem
def get_shared_mem_array(name, shape, dtype): For compatibility with older code that uses ``dgl.utils.shared_mem`` namespace; the
""" Get a tensor from shared memory with specific name content has been moved to ``dgl.ndarray`` module.
"""
Parameters from ..ndarray import get_shared_mem_array, create_shared_mem_array # pylint: disable=unused-import
----------
name : str
The unique name of the shared memory
shape : tuple of int
The shape of the returned tensor
dtype : F.dtype
The dtype of the returned tensor
Returns
-------
F.tensor
The tensor got from shared memory.
"""
name = 'DGL_'+name
new_arr = empty_shared_mem(name, False, shape, F.reverse_data_type_dict[dtype])
dlpack = new_arr.to_dlpack()
return F.zerocopy_from_dlpack(dlpack)
def create_shared_mem_array(name, shape, dtype):
    """Create a tensor backed by a newly allocated shared-memory segment.

    Parameters
    ----------
    name : str
        The unique name of the shared memory
    shape : tuple of int
        The shape of the returned tensor
    dtype : F.dtype
        The dtype of the returned tensor

    Returns
    -------
    F.tensor
        The created tensor.
    """
    # Segments are namespaced with a 'DGL_' prefix — presumably to avoid
    # collisions with other shared-memory users on the host.
    name = 'DGL_' + name
    shm = empty_shared_mem(name, True, shape, F.reverse_data_type_dict[dtype])
    return F.zerocopy_from_dlpack(shm.to_dlpack())
...@@ -1388,10 +1388,6 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1388,10 +1388,6 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
// Prefers converting from COO since it is parallelized. // Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking. // TODO(BarclayII): need benchmarking.
if (!in_csr_->defined()) { if (!in_csr_->defined()) {
// inplace new formats materialization is not allowed for pinned graphs
if (inplace && IsPinned())
LOG(FATAL) << "Cannot create new formats for pinned graphs, " <<
"please create the CSC format before pinning.";
if (coo_->defined()) { if (coo_->defined()) {
const auto& newadj = aten::COOToCSR( const auto& newadj = aten::COOToCSR(
aten::COOTranspose(coo_->adj())); aten::COOTranspose(coo_->adj()));
...@@ -1409,6 +1405,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const { ...@@ -1409,6 +1405,8 @@ UnitGraph::CSRPtr UnitGraph::GetInCSR(bool inplace) const {
else else
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} }
if (inplace && IsPinned())
in_csr_->PinMemory_();
} }
return ret; return ret;
} }
...@@ -1423,10 +1421,6 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1423,10 +1421,6 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
// Prefers converting from COO since it is parallelized. // Prefers converting from COO since it is parallelized.
// TODO(BarclayII): need benchmarking. // TODO(BarclayII): need benchmarking.
if (!out_csr_->defined()) { if (!out_csr_->defined()) {
// inplace new formats materialization is not allowed for pinned graphs
if (inplace && IsPinned())
LOG(FATAL) << "Cannot create new formats for pinned graphs, " <<
"please create the CSR format before pinning.";
if (coo_->defined()) { if (coo_->defined()) {
const auto& newadj = aten::COOToCSR(coo_->adj()); const auto& newadj = aten::COOToCSR(coo_->adj());
...@@ -1443,6 +1437,8 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const { ...@@ -1443,6 +1437,8 @@ UnitGraph::CSRPtr UnitGraph::GetOutCSR(bool inplace) const {
else else
ret = std::make_shared<CSR>(meta_graph(), newadj); ret = std::make_shared<CSR>(meta_graph(), newadj);
} }
if (inplace && IsPinned())
out_csr_->PinMemory_();
} }
return ret; return ret;
} }
...@@ -1455,10 +1451,6 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { ...@@ -1455,10 +1451,6 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
CodeToStr(formats_) << ", cannot create COO matrix."; CodeToStr(formats_) << ", cannot create COO matrix.";
COOPtr ret = coo_; COOPtr ret = coo_;
if (!coo_->defined()) { if (!coo_->defined()) {
// inplace new formats materialization is not allowed for pinned graphs
if (inplace && IsPinned())
LOG(FATAL) << "Cannot create new formats for pinned graphs, " <<
"please create the COO format before pinning.";
if (in_csr_->defined()) { if (in_csr_->defined()) {
const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true)); const auto& newadj = aten::COOTranspose(aten::CSRToCOO(in_csr_->adj(), true));
...@@ -1475,6 +1467,8 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const { ...@@ -1475,6 +1467,8 @@ UnitGraph::COOPtr UnitGraph::GetCOO(bool inplace) const {
else else
ret = std::make_shared<COO>(meta_graph(), newadj); ret = std::make_shared<COO>(meta_graph(), newadj);
} }
if (inplace && IsPinned())
coo_->PinMemory_();
} }
return ret; return ret;
} }
......
...@@ -1009,9 +1009,6 @@ def test_pin_memory_(idtype): ...@@ -1009,9 +1009,6 @@ def test_pin_memory_(idtype):
for etype in g.canonical_etypes: for etype in g.canonical_etypes:
assert F.context(g.batch_num_edges(etype)) == F.cpu() assert F.context(g.batch_num_edges(etype)) == F.cpu()
# not allowed to create new formats for the pinned graph
with pytest.raises(DGLError):
g.create_formats_()
# it's fine to clone with new formats, but new graphs are not pinned # it's fine to clone with new formats, but new graphs are not pinned
# >>> g.formats() # >>> g.formats()
# {'created': ['coo'], 'not created': ['csr', 'csc']} # {'created': ['coo'], 'not created': ['csr', 'csc']}
...@@ -1020,7 +1017,7 @@ def test_pin_memory_(idtype): ...@@ -1020,7 +1017,7 @@ def test_pin_memory_(idtype):
# 'coo' formats is already created and thus not cloned # 'coo' formats is already created and thus not cloned
assert g.formats('coo').is_pinned() assert g.formats('coo').is_pinned()
# pin a pinned graph, direcly return # pin a pinned graph, directly return
g.pin_memory_() g.pin_memory_()
assert g.is_pinned() assert g.is_pinned()
assert g.device == F.cpu() assert g.device == F.cpu()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment