Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
pyg_autoscale
Commits
d43ed0b9
Commit
d43ed0b9
authored
Jun 09, 2021
by
rusty1s
Browse files
dropout adj
parent
67bf815a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
1 deletion
+19
-1
torch_geometric_autoscale/__init__.py
torch_geometric_autoscale/__init__.py
+3
-1
torch_geometric_autoscale/utils.py
torch_geometric_autoscale/utils.py
+16
-0
No files found.
torch_geometric_autoscale/__init__.py
View file @
d43ed0b9
...
@@ -13,7 +13,7 @@ from .data import get_data # noqa
...
@@ -13,7 +13,7 @@ from .data import get_data # noqa
from
.history
import
History
# noqa
from
.history
import
History
# noqa
from
.pool
import
AsyncIOPool
# noqa
from
.pool
import
AsyncIOPool
# noqa
from
.metis
import
metis
,
permute
# noqa
from
.metis
import
metis
,
permute
# noqa
from
.utils
import
compute_micro_f1
# noqa
from
.utils
import
compute_micro_f1
,
gen_masks
,
dropout
# noqa
from
.loader
import
SubgraphLoader
,
EvalSubgraphLoader
# noqa
from
.loader
import
SubgraphLoader
,
EvalSubgraphLoader
# noqa
__all__
=
[
__all__
=
[
...
@@ -23,6 +23,8 @@ __all__ = [
...
@@ -23,6 +23,8 @@ __all__ = [
'metis'
,
'metis'
,
'permute'
,
'permute'
,
'compute_micro_f1'
,
'compute_micro_f1'
,
'gen_masks'
,
'dropout'
,
'SubgraphLoader'
,
'SubgraphLoader'
,
'EvalSubgraphLoader'
,
'EvalSubgraphLoader'
,
'__version__'
,
'__version__'
,
...
...
torch_geometric_autoscale/utils.py
View file @
d43ed0b9
...
@@ -2,6 +2,8 @@ from typing import Optional, Tuple
...
@@ -2,6 +2,8 @@ from typing import Optional, Tuple
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
import
torch.nn.functional
as
F
from
torch_sparse
import
SparseTensor
def
index2mask
(
idx
:
Tensor
,
size
:
int
)
->
Tensor
:
def
index2mask
(
idx
:
Tensor
,
size
:
int
)
->
Tensor
:
...
@@ -54,3 +56,17 @@ def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
...
@@ -54,3 +56,17 @@ def gen_masks(y: Tensor, train_per_class: int = 20, val_per_class: int = 30,
test_mask
=
~
(
train_mask
|
val_mask
)
test_mask
=
~
(
train_mask
|
val_mask
)
return
train_mask
,
val_mask
,
test_mask
return
train_mask
,
val_mask
,
test_mask
def
dropout
(
adj_t
:
SparseTensor
,
p
:
float
,
training
:
bool
=
True
):
if
not
training
:
return
adj_t
if
adj_t
.
storage
.
value
()
is
not
None
:
value
=
F
.
dropout
(
adj_t
.
storage
.
value
(),
p
=
p
)
adj_t
=
adj_t
.
set_value
(
value
,
layout
=
'coo'
)
else
:
mask
=
torch
.
rand
(
adj_t
.
nnz
(),
device
=
adj_t
.
storage
().
row
.
device
)
>
p
adj_t
=
adj_t
.
masked_select_nnz
(
mask
,
layout
=
'coo'
)
return
adj_t
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment