OpenDAS / dgl, commit 704bcaf6 (unverified)

Authored Feb 19, 2023 by Hongzhi (Steve), Chen; committed via GitHub on Feb 19, 2023.

Commit message:

    examples (#5323)

    Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>

Parent: 6bc82161
Changes: 332
Showing 20 changed files with 282 additions and 166 deletions (+282, -166).
Changed files on this page:

examples/pytorch/pointcloud/bipointnet/bipointnet2.py               +48  -21
examples/pytorch/pointcloud/bipointnet/bipointnet_cls.py            +23  -12
examples/pytorch/pointcloud/bipointnet/train_cls.py                 +56  -39
examples/pytorch/pointcloud/edgeconv/main.py                         +3   -3
examples/pytorch/pointcloud/pct/ShapeNet.py                          +3   -3
examples/pytorch/pointcloud/pct/helper.py                            +1   -2
examples/pytorch/pointcloud/pct/train_cls.py                         +2   -3
examples/pytorch/pointcloud/pct/train_partseg.py                     +2   -2
examples/pytorch/pointcloud/point_transformer/ShapeNet.py            +3   -3
examples/pytorch/pointcloud/point_transformer/helper.py              +1   -3
examples/pytorch/pointcloud/point_transformer/point_transformer.py   +1   -1
examples/pytorch/pointcloud/point_transformer/train_cls.py           +2   -3
examples/pytorch/pointcloud/point_transformer/train_partseg.py       +2   -2
examples/pytorch/pointcloud/pointnet/ShapeNet.py                     +3   -3
examples/pytorch/pointcloud/pointnet/pointnet2.py                    +3   -4
examples/pytorch/pointcloud/pointnet/train_cls.py                    +3   -4
examples/pytorch/pointcloud/pointnet/train_partseg.py                +5   -3
examples/pytorch/rect/utils.py                                       +1   -2
examples/pytorch/rgat/train.py                                     +119  -52
examples/pytorch/rgcn-hetero/entity_classify.py                      +1   -1
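Every hunk below follows the same mechanical pattern: single-quoted strings become double-quoted, import blocks are regrouped and re-sorted, and calls that overflow the line limit are exploded onto one argument per line with a trailing comma. The commit message does not name the tooling, so treat the following before/after illustration as an assumption that the changes come from a black-style formatter; the snippet is hypothetical and not taken from the diff.

# Hypothetical snippet, for illustration only -- not part of the commit.
import argparse

parser = argparse.ArgumentParser()
# Before this kind of reformat, the call below would typically sit on one long
# line with single quotes:
#   parser.add_argument('--batch-size', type=int, default=32, help='minibatch size')
# After: double quotes, one argument per line, trailing comma.
parser.add_argument(
    "--batch-size",
    type=int,
    default=32,
    help="minibatch size",
)
args = parser.parse_args([])
print(args.batch_size)  # prints 32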
examples/pytorch/pointcloud/bipointnet/bipointnet2.py  (View file @ 704bcaf6)

 import torch
 import torch.nn as nn
-from basic import BiLinearLSR, BiConv2d, FixedRadiusNNGraph, RelativePositionMessage
 import torch.nn.functional as F
+from basic import (
+    BiConv2d,
+    BiLinearLSR,
+    FixedRadiusNNGraph,
+    RelativePositionMessage,
+)
 from dgl.geometry import farthest_point_sampler


 class BiPointNetConv(nn.Module):
-    '''
+    """
     Feature aggregation
-    '''
+    """

     def __init__(self, sizes, batch_size):
         super(BiPointNetConv, self).__init__()
         self.batch_size = batch_size
         self.conv = nn.ModuleList()
         self.bn = nn.ModuleList()
         for i in range(1, len(sizes)):
             self.conv.append(BiConv2d(sizes[i - 1], sizes[i], 1))
             self.bn.append(nn.BatchNorm2d(sizes[i]))

     def forward(self, nodes):
-        shape = nodes.mailbox['agg_feat'].shape
-        h = nodes.mailbox['agg_feat'].view(self.batch_size, -1, shape[1], shape[2]).permute(0, 3, 2, 1)
+        shape = nodes.mailbox["agg_feat"].shape
+        h = (
+            nodes.mailbox["agg_feat"]
+            .view(self.batch_size, -1, shape[1], shape[2])
+            .permute(0, 3, 2, 1)
+        )
         for conv, bn in zip(self.conv, self.bn):
             h = conv(h)
             h = bn(h)
...
@@ -28,12 +38,12 @@ class BiPointNetConv(nn.Module):
         h = torch.max(h, 2)[0]
         feat_dim = h.shape[1]
         h = h.permute(0, 2, 1).reshape(-1, feat_dim)
-        return {'new_feat': h}
+        return {"new_feat": h}

     def group_all(self, pos, feat):
-        '''
+        """
         Feature aggregation and pooling for the non-sampling layer
-        '''
+        """
         if feat is not None:
             h = torch.cat([pos, feat], 2)
         else:
...
@@ -49,12 +59,21 @@ class BiPointNetConv(nn.Module):
         h = torch.max(h[:, :, :, 0], 2)[0]  # [B,D]
         return new_pos, h


 class BiSAModule(nn.Module):
     """
     The Set Abstraction Layer
     """

-    def __init__(self, npoints, batch_size, radius, mlp_sizes, n_neighbor=64, group_all=False):
+    def __init__(
+        self,
+        npoints,
+        batch_size,
+        radius,
+        mlp_sizes,
+        n_neighbor=64,
+        group_all=False,
+    ):
         super(BiSAModule, self).__init__()
         self.group_all = group_all
         if not group_all:
...
@@ -72,22 +91,30 @@ class BiSAModule(nn.Module):
         g = self.frnn_graph(pos, centroids, feat)
         g.update_all(self.message, self.conv)

-        mask = g.ndata['center'] == 1
-        pos_dim = g.ndata['pos'].shape[-1]
-        feat_dim = g.ndata['new_feat'].shape[-1]
-        pos_res = g.ndata['pos'][mask].view(self.batch_size, -1, pos_dim)
-        feat_res = g.ndata['new_feat'][mask].view(self.batch_size, -1, feat_dim)
+        mask = g.ndata["center"] == 1
+        pos_dim = g.ndata["pos"].shape[-1]
+        feat_dim = g.ndata["new_feat"].shape[-1]
+        pos_res = g.ndata["pos"][mask].view(self.batch_size, -1, pos_dim)
+        feat_res = g.ndata["new_feat"][mask].view(
+            self.batch_size, -1, feat_dim
+        )
         return pos_res, feat_res


 class BiPointNet2SSGCls(nn.Module):
-    def __init__(self, output_classes, batch_size, input_dims=3, dropout_prob=0.4):
+    def __init__(
+        self, output_classes, batch_size, input_dims=3, dropout_prob=0.4
+    ):
         super(BiPointNet2SSGCls, self).__init__()
         self.input_dims = input_dims

-        self.sa_module1 = BiSAModule(512, batch_size, 0.2, [input_dims, 64, 64, 128])
-        self.sa_module2 = BiSAModule(128, batch_size, 0.4, [128 + 3, 128, 128, 256])
-        self.sa_module3 = BiSAModule(None, batch_size, None, [256 + 3, 256, 512, 1024],
-                                     group_all=True)
+        self.sa_module1 = BiSAModule(
+            512, batch_size, 0.2, [input_dims, 64, 64, 128]
+        )
+        self.sa_module2 = BiSAModule(
+            128, batch_size, 0.4, [128 + 3, 128, 128, 256]
+        )
+        self.sa_module3 = BiSAModule(
+            None, batch_size, None, [256 + 3, 256, 512, 1024], group_all=True
+        )
         self.mlp1 = BiLinearLSR(1024, 512)
         self.bn1 = nn.BatchNorm1d(512)
...
examples/pytorch/pointcloud/bipointnet/bipointnet_cls.py  (View file @ 704bcaf6)

+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable
-import numpy as np
 from basic import BiLinear
+from torch.autograd import Variable

 offset_map = {1024: -3.2041, 2048: -3.4025, 4096: -3.5836}


 class Conv1d(nn.Module):
     def __init__(self, inplane, outplane, Linear):
...
@@ -38,9 +35,16 @@ class EmaMaxPool(nn.Module):
         x = torch.max(x, 2, keepdim=True)[0] - 0.3
         return x

+
 class BiPointNetCls(nn.Module):
-    def __init__(self, output_classes, input_dims=3, conv1_dim=64,
-                 use_transform=True, Linear=BiLinear):
+    def __init__(
+        self,
+        output_classes,
+        input_dims=3,
+        conv1_dim=64,
+        use_transform=True,
+        Linear=BiLinear,
+    ):
         super(BiPointNetCls, self).__init__()
         self.input_dims = input_dims
         self.conv1 = nn.ModuleList()
...
@@ -119,6 +123,7 @@ class BiPointNetCls(nn.Module):
         out = self.mlp_out(h)
         return out

+
 class TransformNet(nn.Module):
     def __init__(self, input_dims=3, conv1_dim=64, Linear=BiLinear):
         super(TransformNet, self).__init__()
...
@@ -153,7 +158,7 @@ class TransformNet(nn.Module):
             h = conv(h)
             h = bn(h)
             h = F.relu(h)

         h = self.maxpool(h).view(-1, self.pool_feat_len)
         for mlp, bn in zip(self.mlp2, self.bn2):
             h = mlp(h)
...
@@ -162,8 +167,14 @@ class TransformNet(nn.Module):
         out = self.mlp_out(h)

-        iden = Variable(torch.from_numpy(np.eye(self.input_dims).flatten().astype(np.float32)))
-        iden = iden.view(1, self.input_dims * self.input_dims).repeat(batch_size, 1)
+        iden = Variable(
+            torch.from_numpy(
+                np.eye(self.input_dims).flatten().astype(np.float32)
+            )
+        )
+        iden = iden.view(1, self.input_dims * self.input_dims).repeat(
+            batch_size, 1
+        )
         if out.is_cuda:
             iden = iden.cuda()
         out = out + iden
...
examples/pytorch/pointcloud/bipointnet/train_cls.py  (View file @ 704bcaf6)

-from bipointnet_cls import BiPointNetCls
-from bipointnet2 import BiPointNet2SSGCls
-from ModelNetDataLoader import ModelNetDataLoader
-import provider
 import argparse
 import os
 import urllib
-import tqdm
 from functools import partial
-from dgl.data.utils import download, get_download_dir
+
 import dgl
-from torch.utils.data import DataLoader
-import torch.optim as optim
-import torch.nn.functional as F
-import torch.nn as nn
+import provider
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import tqdm
+from bipointnet2 import BiPointNet2SSGCls
+from bipointnet_cls import BiPointNetCls
+from dgl.data.utils import download, get_download_dir
+from ModelNetDataLoader import ModelNetDataLoader
+from torch.utils.data import DataLoader

 torch.backends.cudnn.enabled = False
 # from dataset import ModelNet

 parser = argparse.ArgumentParser()
-parser.add_argument('--model', type=str, default='bipointnet')
-parser.add_argument('--dataset-path', type=str, default='')
-parser.add_argument('--load-model-path', type=str, default='')
-parser.add_argument('--save-model-path', type=str, default='')
-parser.add_argument('--num-epochs', type=int, default=200)
-parser.add_argument('--num-workers', type=int, default=0)
-parser.add_argument('--batch-size', type=int, default=32)
+parser.add_argument("--model", type=str, default="bipointnet")
+parser.add_argument("--dataset-path", type=str, default="")
+parser.add_argument("--load-model-path", type=str, default="")
+parser.add_argument("--save-model-path", type=str, default="")
+parser.add_argument("--num-epochs", type=int, default=200)
+parser.add_argument("--num-workers", type=int, default=0)
+parser.add_argument("--batch-size", type=int, default=32)
 args = parser.parse_args()

 num_workers = args.num_workers
 batch_size = args.batch_size

-data_filename = 'modelnet40_normal_resampled.zip'
+data_filename = "modelnet40_normal_resampled.zip"
 download_path = os.path.join(get_download_dir(), data_filename)
 local_path = args.dataset_path or os.path.join(
-    get_download_dir(), 'modelnet40_normal_resampled')
+    get_download_dir(), "modelnet40_normal_resampled"
+)

 if not os.path.exists(local_path):
-    download('https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip',
-             download_path, verify_ssl=False)
+    download(
+        "https://shapenet.cs.stanford.edu/media/modelnet40_normal_resampled.zip",
+        download_path,
+        verify_ssl=False,
+    )
     from zipfile import ZipFile

     with ZipFile(download_path) as z:
         z.extractall(path=get_download_dir())
...
@@ -48,11 +55,11 @@ CustomDataLoader = partial(
     num_workers=num_workers,
     batch_size=batch_size,
     shuffle=True,
-    drop_last=True)
+    drop_last=True,
+)


 def train(net, opt, scheduler, train_loader, dev):
     net.train()

     total_loss = 0
...
@@ -64,8 +71,7 @@ def train(net, opt, scheduler, train_loader, dev):
         for data, label in tq:
             data = data.data.numpy()
             data = provider.random_point_dropout(data)
             data[:, :, 0:3] = provider.random_scale_point_cloud(
                 data[:, :, 0:3])
             data[:, :, 0:3] = provider.jitter_point_cloud(data[:, :, 0:3])
             data[:, :, 0:3] = provider.shift_point_cloud(data[:, :, 0:3])
             data = torch.tensor(data)
...
@@ -88,9 +94,12 @@ def train(net, opt, scheduler, train_loader, dev):
             total_loss += loss
             total_correct += correct

-            tq.set_postfix({
-                'AvgLoss': '%.5f' % (total_loss / num_batches),
-                'AvgAcc': '%.5f' % (total_correct / count)})
+            tq.set_postfix(
+                {
+                    "AvgLoss": "%.5f" % (total_loss / num_batches),
+                    "AvgAcc": "%.5f" % (total_correct / count),
+                }
+            )
     scheduler.step()
...
@@ -113,17 +122,16 @@ def evaluate(net, test_loader, dev):
                 total_correct += correct
                 count += num_examples

-                tq.set_postfix({
-                    'AvgAcc': '%.5f' % (total_correct / count)})
+                tq.set_postfix({"AvgAcc": "%.5f" % (total_correct / count)})
     return total_correct / count


 dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-if args.model == 'bipointnet':
+if args.model == "bipointnet":
     net = BiPointNetCls(40, input_dims=6)
-elif args.model == 'bipointnet2_ssg':
+elif args.model == "bipointnet2_ssg":
     net = BiPointNet2SSGCls(40, batch_size, input_dims=6)

 net = net.to(dev)
...
@@ -134,23 +142,32 @@ opt = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
 scheduler = optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.7)

-train_dataset = ModelNetDataLoader(local_path, 1024, split='train')
-test_dataset = ModelNetDataLoader(local_path, 1024, split='test')
+train_dataset = ModelNetDataLoader(local_path, 1024, split="train")
+test_dataset = ModelNetDataLoader(local_path, 1024, split="test")
 train_loader = torch.utils.data.DataLoader(
-    train_dataset, batch_size=batch_size, shuffle=True,
-    num_workers=num_workers, drop_last=True)
+    train_dataset,
+    batch_size=batch_size,
+    shuffle=True,
+    num_workers=num_workers,
+    drop_last=True,
+)
 test_loader = torch.utils.data.DataLoader(
-    test_dataset, batch_size=batch_size, shuffle=False,
-    num_workers=num_workers, drop_last=True)
+    test_dataset,
+    batch_size=batch_size,
+    shuffle=False,
+    num_workers=num_workers,
+    drop_last=True,
+)

 best_test_acc = 0

 for epoch in range(args.num_epochs):
     train(net, opt, scheduler, train_loader, dev)
     if (epoch + 1) % 1 == 0:
-        print('Epoch #%d Testing' % epoch)
+        print("Epoch #%d Testing" % epoch)
         test_acc = evaluate(net, test_loader, dev)
         if test_acc > best_test_acc:
             best_test_acc = test_acc
             if args.save_model_path:
                 torch.save(net.state_dict(), args.save_model_path)
-        print('Current test acc: %.5f (best: %.5f)' % (
-            test_acc, best_test_acc))
+        print(
+            "Current test acc: %.5f (best: %.5f)" % (test_acc, best_test_acc)
+        )
examples/pytorch/pointcloud/edgeconv/main.py  (View file @ 704bcaf6)

@@ -8,11 +8,11 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import tqdm
-from model import Model, compute_loss
-from modelnet import ModelNet
-from torch.utils.data import DataLoader
 from dgl.data.utils import download, get_download_dir
+from model import compute_loss, Model
+from modelnet import ModelNet
+from torch.utils.data import DataLoader

 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset-path", type=str, default="")
...
examples/pytorch/pointcloud/pct/ShapeNet.py  (View file @ 704bcaf6)

@@ -2,14 +2,14 @@ import json
 import os
 from zipfile import ZipFile

+import dgl
 import numpy as np
 import tqdm
+from dgl.data.utils import download, get_download_dir
 from scipy.sparse import csr_matrix
 from torch.utils.data import Dataset

-import dgl
-from dgl.data.utils import download, get_download_dir


 class ShapeNet(object):
     def __init__(self, num_points=2048, normal_channel=True):
...
examples/pytorch/pointcloud/pct/helper.py  (View file @ 704bcaf6)

+import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-import dgl
 from dgl.geometry import farthest_point_sampler

 """
...
examples/pytorch/pointcloud/pct/train_cls.py  (View file @ 704bcaf6)

@@ -7,12 +7,12 @@ import provider
 import torch
 import torch.nn as nn
 import tqdm
+from dgl.data.utils import download, get_download_dir
 from ModelNetDataLoader import ModelNetDataLoader
 from pct import PointTransformerCLS
 from torch.utils.data import DataLoader

-from dgl.data.utils import download, get_download_dir
-
 torch.backends.cudnn.enabled = False

 parser = argparse.ArgumentParser()
...
@@ -54,7 +54,6 @@ CustomDataLoader = partial(

 def train(net, opt, scheduler, train_loader, dev):
     net.train()

     total_loss = 0
...
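The new import order in the hunks above (for example, dgl.data.utils before ModelNetDataLoader before torch.utils.data) is what a case-insensitive alphabetical sort of the module names produces. The commit does not say which import sorter was used, so this is only an observation about the resulting order; a quick, self-contained way to check the rule:

# Illustrative only: reproduce the ordering of the import block seen above
# with a case-insensitive sort of the module names.
modules = ["torch.utils.data", "ModelNetDataLoader", "dgl.data.utils", "pct"]
print(sorted(modules, key=str.lower))
# ['dgl.data.utils', 'ModelNetDataLoader', 'pct', 'torch.utils.data']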
examples/pytorch/pointcloud/pct/train_partseg.py  (View file @ 704bcaf6)

@@ -2,6 +2,8 @@ import argparse
 import time
 from functools import partial

+import dgl
 import numpy as np
 import provider
 import torch
...
@@ -11,8 +13,6 @@ from pct import PartSegLoss, PointTransformerSeg
 from ShapeNet import ShapeNet
 from torch.utils.data import DataLoader

-import dgl
-
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset-path", type=str, default="")
 parser.add_argument("--load-model-path", type=str, default="")
...
examples/pytorch/pointcloud/point_transformer/ShapeNet.py  (View file @ 704bcaf6)

@@ -2,14 +2,14 @@ import json
 import os
 from zipfile import ZipFile

+import dgl
 import numpy as np
 import tqdm
+from dgl.data.utils import download, get_download_dir
 from scipy.sparse import csr_matrix
 from torch.utils.data import Dataset

-import dgl
-from dgl.data.utils import download, get_download_dir


 class ShapeNet(object):
     def __init__(self, num_points=2048, normal_channel=True):
...
examples/pytorch/pointcloud/point_transformer/helper.py  (View file @ 704bcaf6)

+import dgl
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-import dgl
 from dgl.geometry import farthest_point_sampler

 """
...
@@ -270,7 +269,6 @@ class TransitionUp(nn.Module):
     """

     def __init__(self, dim1, dim2, dim_out):
         super(TransitionUp, self).__init__()
         self.fc1 = nn.Sequential(
             nn.Linear(dim1, dim_out),
...
examples/pytorch/pointcloud/point_transformer/point_transformer.py  (View file @ 704bcaf6)

 import numpy as np
 import torch
-from helper import TransitionDown, TransitionUp, index_points, square_distance
+from helper import index_points, square_distance, TransitionDown, TransitionUp
 from torch import nn

 """
...
examples/pytorch/pointcloud/point_transformer/train_cls.py  (View file @ 704bcaf6)

@@ -7,12 +7,12 @@ import provider
 import torch
 import torch.nn as nn
 import tqdm
+from dgl.data.utils import download, get_download_dir
 from ModelNetDataLoader import ModelNetDataLoader
 from point_transformer import PointTransformerCLS
 from torch.utils.data import DataLoader

-from dgl.data.utils import download, get_download_dir
-
 torch.backends.cudnn.enabled = False

 parser = argparse.ArgumentParser()
...
@@ -55,7 +55,6 @@ CustomDataLoader = partial(

 def train(net, opt, scheduler, train_loader, dev):
     net.train()

     total_loss = 0
...
examples/pytorch/pointcloud/point_transformer/train_partseg.py  (View file @ 704bcaf6)

@@ -2,6 +2,8 @@ import argparse
 import time
 from functools import partial

+import dgl
 import numpy as np
 import torch
 import torch.optim as optim
...
@@ -10,8 +12,6 @@ from point_transformer import PartSegLoss, PointTransformerSeg
 from ShapeNet import ShapeNet
 from torch.utils.data import DataLoader

-import dgl
-
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset-path", type=str, default="")
 parser.add_argument("--load-model-path", type=str, default="")
...
examples/pytorch/pointcloud/pointnet/ShapeNet.py  (View file @ 704bcaf6)

@@ -2,14 +2,14 @@ import json
 import os
 from zipfile import ZipFile

+import dgl
 import numpy as np
 import tqdm
+from dgl.data.utils import download, get_download_dir
 from scipy.sparse import csr_matrix
 from torch.utils.data import Dataset

-import dgl
-from dgl.data.utils import download, get_download_dir


 class ShapeNet(object):
     def __init__(self, num_points=2048, normal_channel=True):
...
examples/pytorch/pointcloud/pointnet/pointnet2.py  (View file @ 704bcaf6)

+import dgl
+import dgl.function as fn
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.autograd import Variable
-
-import dgl
-import dgl.function as fn
 from dgl.geometry import (
     farthest_point_sampler,
 )  # dgl.geometry.pytorch -> dgl.geometry
+from torch.autograd import Variable

 """
 Part of the code are adapted from
...
examples/pytorch/pointcloud/pointnet/train_cls.py  (View file @ 704bcaf6)

@@ -3,20 +3,20 @@ import os
 import urllib
 from functools import partial

+import dgl
 import provider
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import tqdm
+from dgl.data.utils import download, get_download_dir
 from ModelNetDataLoader import ModelNetDataLoader
 from pointnet2 import PointNet2MSGCls, PointNet2SSGCls
 from pointnet_cls import PointNetCls
 from torch.utils.data import DataLoader

-import dgl
-from dgl.data.utils import download, get_download_dir
-
 torch.backends.cudnn.enabled = False
...
@@ -62,7 +62,6 @@ CustomDataLoader = partial(

 def train(net, opt, scheduler, train_loader, dev):
     net.train()

     total_loss = 0
...
examples/pytorch/pointcloud/pointnet/train_partseg.py  (View file @ 704bcaf6)

@@ -4,20 +4,20 @@ import time
 import urllib
 from functools import partial

+import dgl
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import tqdm
+from dgl.data.utils import download, get_download_dir
 from pointnet2_partseg import PointNet2MSGPartSeg, PointNet2SSGPartSeg
 from pointnet_partseg import PartSegLoss, PointNetPartSeg
 from ShapeNet import ShapeNet
 from torch.utils.data import DataLoader

-import dgl
-from dgl.data.utils import download, get_download_dir
-
 parser = argparse.ArgumentParser()
 parser.add_argument("--model", type=str, default="pointnet")
 parser.add_argument("--dataset-path", type=str, default="")
...
@@ -260,6 +260,8 @@ color_map = torch.tensor(
         [255, 105, 180],
     ]
 )
+
+
 # paint each point according to its pred
 def paint(batched_points):
     B, N = batched_points.shape
...
examples/pytorch/rect/utils.py  (View file @ 704bcaf6)

-import torch
 import dgl
+import torch
 from dgl.data import CiteseerGraphDataset, CoraGraphDataset
...
examples/pytorch/rgat/train.py  (View file @ 704bcaf6)

+import dgl
+import dgl.function as fn
+import dgl.nn as dglnn
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchmetrics.functional as MF
-import dgl
-import dgl.function as fn
-import dgl.nn as dglnn
-from dgl.dataloading import NeighborSampler, DataLoader
+import tqdm
 from dgl import apply_each
+from dgl.dataloading import DataLoader, NeighborSampler
 from ogb.nodeproppred import DglNodePropPredDataset
-import tqdm


 class HeteroGAT(nn.Module):
     def __init__(self, etypes, in_size, hid_size, out_size, n_heads=4):
         super().__init__()
         self.layers = nn.ModuleList()
-        self.layers.append(dglnn.HeteroGraphConv({
-            etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)
-            for etype in etypes}))
-        self.layers.append(dglnn.HeteroGraphConv({
-            etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
-            for etype in etypes}))
-        self.layers.append(dglnn.HeteroGraphConv({
-            etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
-            for etype in etypes}))
+        self.layers.append(
+            dglnn.HeteroGraphConv(
+                {
+                    etype: dglnn.GATConv(in_size, hid_size // n_heads, n_heads)
+                    for etype in etypes
+                }
+            )
+        )
+        self.layers.append(
+            dglnn.HeteroGraphConv(
+                {
+                    etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
+                    for etype in etypes
+                }
+            )
+        )
+        self.layers.append(
+            dglnn.HeteroGraphConv(
+                {
+                    etype: dglnn.GATConv(hid_size, hid_size // n_heads, n_heads)
+                    for etype in etypes
+                }
+            )
+        )
         self.dropout = nn.Dropout(0.5)
         self.linear = nn.Linear(hid_size, out_size)  # Should be HeteroLinear

     def forward(self, blocks, x):
         h = x
...
@@ -32,19 +48,24 @@ class HeteroGAT(nn.Module):
             h = layer(block, h)
             # One thing is that h might return tensors with zero rows if the number of dst nodes
             # of one node type is 0. x.view(x.shape[0], -1) wouldn't work in this case.
-            h = apply_each(h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2]))
+            h = apply_each(
+                h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2])
+            )
             if l != len(self.layers) - 1:
                 h = apply_each(h, F.relu)
                 h = apply_each(h, self.dropout)
-        return self.linear(h['paper'])
+        return self.linear(h["paper"])


 def evaluate(model, dataloader, desc):
     preds = []
     labels = []
     with torch.no_grad():
-        for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader, desc=desc):
-            x = blocks[0].srcdata['feat']
-            y = blocks[-1].dstdata['label']['paper'][:, 0]
+        for input_nodes, output_nodes, blocks in tqdm.tqdm(
+            dataloader, desc=desc
+        ):
+            x = blocks[0].srcdata["feat"]
+            y = blocks[-1].dstdata["label"]["paper"][:, 0]
             y_hat = model(blocks, x)
             preds.append(y_hat.cpu())
             labels.append(y.cpu())
...
@@ -53,6 +74,7 @@ def evaluate(model, dataloader, desc):
     acc = MF.accuracy(preds, labels)
     return acc

+
 def train(train_loader, val_loader, test_loader, model):
     # loss function and optimizer
     loss_fcn = nn.CrossEntropyLoss()
...
@@ -62,9 +84,11 @@ def train(train_loader, val_loader, test_loader, model):
     for epoch in range(10):
         model.train()
         total_loss = 0
-        for it, (input_nodes, output_nodes, blocks) in enumerate(tqdm.tqdm(train_dataloader, desc="Train")):
-            x = blocks[0].srcdata['feat']
-            y = blocks[-1].dstdata['label']['paper'][:, 0]
+        for it, (input_nodes, output_nodes, blocks) in enumerate(
+            tqdm.tqdm(train_dataloader, desc="Train")
+        ):
+            x = blocks[0].srcdata["feat"]
+            y = blocks[-1].dstdata["label"]["paper"][:, 0]
             y_hat = model(blocks, x)
             loss = loss_fcn(y_hat, y)
             opt.zero_grad()
...
@@ -72,51 +96,94 @@ def train(train_loader, val_loader, test_loader, model):
             opt.step()
             total_loss += loss.item()
         model.eval()
-        val_acc = evaluate(model, val_dataloader, 'Val. ')
-        test_acc = evaluate(model, test_dataloader, 'Test ')
-        print(f'Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}')
+        val_acc = evaluate(model, val_dataloader, "Val. ")
+        test_acc = evaluate(model, test_dataloader, "Test ")
+        print(
+            f"Epoch {epoch:05d} | Loss {total_loss / (it + 1):.4f} | Validation Acc. {val_acc.item():.4f} | Test Acc. {test_acc.item():.4f}"
+        )


-if __name__ == '__main__':
-    print(f'Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules')
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+if __name__ == "__main__":
+    print(
+        f"Training with DGL built-in HeteroGraphConv using GATConv as its convolution sub-modules"
+    )
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

     # load and preprocess dataset
-    print('Loading data')
-    dataset = DglNodePropPredDataset('ogbn-mag')
+    print("Loading data")
+    dataset = DglNodePropPredDataset("ogbn-mag")
     graph, labels = dataset[0]
-    graph.ndata['label'] = labels
+    graph.ndata["label"] = labels
     # add reverse edges in "cites" relation, and add reverse edge types for the rest etypes
     graph = dgl.AddReverse()(graph)
     # precompute the author, topic, and institution features
-    graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='rev_writes')
-    graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='has_topic')
-    graph.update_all(fn.copy_u('feat', 'm'), fn.mean('m', 'feat'), etype='affiliated_with')
+    graph.update_all(
+        fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="rev_writes"
+    )
+    graph.update_all(
+        fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="has_topic"
+    )
+    graph.update_all(
+        fn.copy_u("feat", "m"), fn.mean("m", "feat"), etype="affiliated_with"
+    )
     # find train/val/test indexes
     split_idx = dataset.get_idx_split()
-    train_idx, val_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
+    train_idx, val_idx, test_idx = (
+        split_idx["train"],
+        split_idx["valid"],
+        split_idx["test"],
+    )
     train_idx = apply_each(train_idx, lambda x: x.to(device))
     val_idx = apply_each(val_idx, lambda x: x.to(device))
     test_idx = apply_each(test_idx, lambda x: x.to(device))

     # create RGAT model
-    in_size = graph.ndata['feat']['paper'].shape[1]
+    in_size = graph.ndata["feat"]["paper"].shape[1]
     out_size = dataset.num_classes
     model = HeteroGAT(graph.etypes, in_size, 256, out_size).to(device)

     # dataloader + model training + testing
-    train_sampler = NeighborSampler([5, 5, 5],
-                                    prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
-                                    prefetch_labels={'paper': ['label']})
-    val_sampler = NeighborSampler([10, 10, 10],
-                                  prefetch_node_feats={k: ['feat'] for k in graph.ntypes},
-                                  prefetch_labels={'paper': ['label']})
-    train_dataloader = DataLoader(graph, train_idx, train_sampler,
-                                  device=device, batch_size=1000, shuffle=True,
-                                  drop_last=False, num_workers=0, use_uva=torch.cuda.is_available())
-    val_dataloader = DataLoader(graph, val_idx, val_sampler,
-                                device=device, batch_size=1000, shuffle=False,
-                                drop_last=False, num_workers=0, use_uva=torch.cuda.is_available())
-    test_dataloader = DataLoader(graph, test_idx, val_sampler,
-                                 device=device, batch_size=1000, shuffle=False,
-                                 drop_last=False, num_workers=0, use_uva=torch.cuda.is_available())
+    train_sampler = NeighborSampler(
+        [5, 5, 5],
+        prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
+        prefetch_labels={"paper": ["label"]},
+    )
+    val_sampler = NeighborSampler(
+        [10, 10, 10],
+        prefetch_node_feats={k: ["feat"] for k in graph.ntypes},
+        prefetch_labels={"paper": ["label"]},
+    )
+    train_dataloader = DataLoader(
+        graph,
+        train_idx,
+        train_sampler,
+        device=device,
+        batch_size=1000,
+        shuffle=True,
+        drop_last=False,
+        num_workers=0,
+        use_uva=torch.cuda.is_available(),
+    )
+    val_dataloader = DataLoader(
+        graph,
+        val_idx,
+        val_sampler,
+        device=device,
+        batch_size=1000,
+        shuffle=False,
+        drop_last=False,
+        num_workers=0,
+        use_uva=torch.cuda.is_available(),
+    )
+    test_dataloader = DataLoader(
+        graph,
+        test_idx,
+        val_sampler,
+        device=device,
+        batch_size=1000,
+        shuffle=False,
+        drop_last=False,
+        num_workers=0,
+        use_uva=torch.cuda.is_available(),
+    )

     train(train_dataloader, val_dataloader, test_dataloader, model)
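One step in the reformatted rgat example that is easy to miss among the re-wrapping is the dgl.apply_each call, which flattens the per-head GATConv outputs for every node type in a single pass. A minimal sketch of that step in isolation, with made-up node types and tensor shapes:

# Minimal sketch of the apply_each flattening step used in the example above.
# The node types and tensor shapes here are placeholders, not from the commit.
import torch
from dgl import apply_each

h = {
    "paper": torch.randn(5, 4, 64),   # (num_dst_nodes, num_heads, head_feat)
    "author": torch.randn(3, 4, 64),
}
# Merge the head and feature dimensions for every node type at once.
h = apply_each(h, lambda x: x.view(x.shape[0], x.shape[1] * x.shape[2]))
print({k: tuple(v.shape) for k, v in h.items()})
# {'paper': (5, 256), 'author': (3, 256)}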
examples/pytorch/rgcn-hetero/entity_classify.py  (View file @ 704bcaf6)

@@ -9,9 +9,9 @@ import numpy as np
 import torch as th
 import torch.nn as nn
 import torch.nn.functional as F
-from model import EntityClassify
 from dgl.data.rdf import AIFBDataset, AMDataset, BGSDataset, MUTAGDataset
+from model import EntityClassify


 def main(args):
...