Commit d43ed0b9 authored by rusty1s's avatar rusty1s
Browse files

dropout adj

parent 67bf815a
......@@ -13,7 +13,7 @@ from .data import get_data # noqa
from .history import History # noqa
from .pool import AsyncIOPool # noqa
from .metis import metis, permute # noqa
from .utils import compute_micro_f1 # noqa
from .utils import compute_micro_f1, gen_masks, dropout # noqa
from .loader import SubgraphLoader, EvalSubgraphLoader # noqa
__all__ = [
......@@ -23,6 +23,8 @@ __all__ = [
'metis',
'permute',
'compute_micro_f1',
'gen_masks',
'dropout',
'SubgraphLoader',
'EvalSubgraphLoader',
'__version__',
......
......@@ -2,6 +2,8 @@ from typing import Optional, Tuple
import torch
from torch import Tensor
import torch.nn.functional as F
from torch_sparse import SparseTensor
def index2mask(idx: Tensor, size: int) -> Tensor:
......@@ -54,3 +56,17 @@ def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
test_mask = ~(train_mask | val_mask)
return train_mask, val_mask, test_mask
def dropout(adj_t: SparseTensor, p: float, training: bool = True):
if not training:
return adj_t
if adj_t.storage.value() is not None:
value = F.dropout(adj_t.storage.value(), p=p)
adj_t = adj_t.set_value(value, layout='coo')
else:
mask = torch.rand(adj_t.nnz(), device=adj_t.storage().row.device) > p
adj_t = adj_t.masked_select_nnz(mask, layout='coo')
return adj_t
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment