Unverified Commit 34a067ea authored by Zihao Ye, committed by GitHub

[bugfix] Fix the memory leak issue of Cluster GAT under 0.5 kernel and simplify the bipartite GAT. (#1908)

* upd

* upd

* upd

* upd

* upd
parent 303e4236
@@ -95,8 +95,6 @@ class GAT(nn.Module):
                 drop_last=False,
                 num_workers=args.num_workers)
-            layer.fc_src = layer.fc
-            layer.fc_dst = layer.fc
             for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                 block = blocks[0].to(device)
                 h = x[input_nodes].to(device)
......
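The two deleted assignments were a workaround for the old bipartite code path: the layer is built with a scalar `in_feats` (so only `self.fc` exists), but feeding it tuple features during block inference used to route through `self.fc_src`/`self.fc_dst`, hence the aliasing. With the tuple branch now calling the shared `self.fc` (see the GATConv hunk below), the monkey-patch can go. A sketch of the loop as it stands after the change; `layer`, `x`, and `device` follow the example's surrounding code:

```python
# After the fix: no fc_src/fc_dst aliasing is needed before the loop.
for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
    block = blocks[0].to(device)
    h = x[input_nodes].to(device)
    # In a DGL block, the first number_of_dst_nodes() input nodes are the
    # destination nodes, so dst features are a prefix slice of the input.
    h_dst = h[:block.number_of_dst_nodes()]
    h = layer(block, (h, h_dst))
```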
@@ -522,7 +522,6 @@ class HeteroGraphIndex(ObjectBase):
         eid = F.from_dgl_nd(edge_array(2))
         return src, dst, eid

-    @utils.cached_member(cache='_cache', prefix='edges')
     def edges(self, etype, order=None):
         """Return all the edges
@@ -821,7 +820,6 @@ class HeteroGraphIndex(ObjectBase):
         eids = [F.to_dgl_nd(edges) for edges in induced_edges]
         return _CAPI_DGLHeteroEdgeSubgraph(self, eids, preserve_nodes)

-    @utils.cached_member(cache='_cache', prefix='unitgraph')
     def get_unitgraph(self, etype, ctx):
         """Create a unitgraph graph from given edge type and copy to the given device
         context.
@@ -912,7 +910,6 @@ class HeteroGraphIndex(ObjectBase):
         """Create all sparse matrices allowed for the graph."""
         return _CAPI_DGLHeteroCreateFormat(self)

-    @utils.cached_member(cache='_cache', prefix='reverse')
     def reverse(self):
         """Reverse the heterogeneous graph adjacency
......
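The three removed `@utils.cached_member` decorators memoized the results of `edges`, `get_unitgraph`, and `reverse` in the graph index's `_cache` dict. Because that cache lives as long as the graph object, anything it holds, such as a unit graph copied to GPU by `get_unitgraph`, stays pinned until the graph itself is collected, which is how minibatch workloads like Cluster GAT can accumulate device memory. A minimal sketch of how such a cache-on-self decorator is commonly written (the actual `dgl.utils.cached_member` may differ in detail):

```python
import functools

def cached_member(cache, prefix):
    """Memoize a method's result in a dict stored on the instance.

    Note the retention hazard: the memoized value, and anything it
    references (e.g. a graph copied to GPU), stays alive for the
    lifetime of `self`.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args):
            key = '{}-{}'.format(prefix, args)
            store = getattr(self, cache, None)
            if store is None:
                store = {}
                setattr(self, cache, store)
            if key not in store:
                store[key] = func(self, *args)
            return store[key]
        return wrapper
    return decorator
```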
@@ -26,13 +26,8 @@ class GATConv(nn.Module):
     Parameters
     ----------
-    in_feats : int, or pair of ints
+    in_feats : int
         Input feature size.
-        If the layer is to be applied to a unidirectional bipartite graph, ``in_feats``
-        specifies the input feature size on both the source and destination nodes. If
-        a scalar is given, the source and destination node feature size would take the
-        same value.
     out_feats : int
         Output feature size.
     num_heads : int
@@ -62,12 +57,6 @@ class GATConv(nn.Module):
         self._num_heads = num_heads
         self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
         self._out_feats = out_feats
-        if isinstance(in_feats, tuple):
-            self.fc_src = nn.Linear(
-                self._in_src_feats, out_feats * num_heads, bias=False)
-            self.fc_dst = nn.Linear(
-                self._in_dst_feats, out_feats * num_heads, bias=False)
-        else:
-            self.fc = nn.Linear(
-                self._in_src_feats, out_feats * num_heads, bias=False)
+        self.fc = nn.Linear(
+            self._in_src_feats, out_feats * num_heads, bias=False)
         self.attn_l = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
@@ -89,11 +78,7 @@ class GATConv(nn.Module):
     def reset_parameters(self):
         """Reinitialize learnable parameters."""
         gain = nn.init.calculate_gain('relu')
-        if hasattr(self, 'fc'):
-            nn.init.xavier_normal_(self.fc.weight, gain=gain)
-        else:  # bipartite graph neural networks
-            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
-            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
+        nn.init.xavier_normal_(self.fc.weight, gain=gain)
         nn.init.xavier_normal_(self.attn_l, gain=gain)
         nn.init.xavier_normal_(self.attn_r, gain=gain)
         if isinstance(self.res_fc, nn.Linear):
@@ -122,8 +107,8 @@ class GATConv(nn.Module):
         if isinstance(feat, tuple):
             h_src = self.feat_drop(feat[0])
             h_dst = self.feat_drop(feat[1])
-            feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
-            feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
+            feat_src = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
+            feat_dst = self.fc(h_dst).view(-1, self._num_heads, self._out_feats)
         else:
             h_src = h_dst = self.feat_drop(feat)
             feat_src = feat_dst = self.fc(h_src).view(
......
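After the simplification, a single shared `fc` projects both sides of a bipartite graph, so source and destination features must have the same input width. A small usage sketch mirroring the updated unit test below; the toy heterograph and node type names are illustrative:

```python
import torch as th
import dgl
from dgl.nn.pytorch import GATConv

# A toy unidirectional bipartite graph: 3 'user' src nodes, 2 'game' dst nodes.
g = dgl.heterograph({('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1])})

gat = GATConv(5, 2, num_heads=4)               # one in_feats for both sides
feat = (th.randn(g.number_of_src_nodes(), 5),  # src features
        th.randn(g.number_of_dst_nodes(), 5))  # dst features: same width now
h = gat(g, feat)
assert h.shape == (g.number_of_dst_nodes(), 4, 2)  # (dst, heads, out_feats)
```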
@@ -140,6 +140,9 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
   csr->sorted = true;
   csr->indices = new_indices;
   csr->data = new_data;
+
+  // free resources
+  device->FreeWorkspace(ctx, workspace);
 }

 template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
......
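The one-line CUDA fix pairs the workspace allocated earlier in `CSRSort_` (the `workspace`, `ctx`, and `device` already in scope) with a `FreeWorkspace`; before this patch, every GPU sort of an int64 CSR matrix left its temporary sort buffer allocated. DGL's workspace allocator is separate from PyTorch's caching allocator, so growth shows up in NVML/`nvidia-smi` rather than `torch.cuda.memory_allocated()`. A hedged watch-loop for spotting this kind of leak; which concrete op reaches `CSRSort_` depends on the workload, and `in_degrees` on a fresh random graph is only an illustrative trigger of the CUDA CSR paths:

```python
import dgl
import torch as th
import pynvml  # pip install nvidia-ml-py

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def used_mib():
    return pynvml.nvmlDeviceGetMemoryInfo(handle).used / 2**20

base = used_mib()
for i in range(50):
    g = dgl.rand_graph(10000, 100000).to(th.device('cuda:0'))
    g.in_degrees()            # exercises the CUDA CSR kernels
    th.cuda.synchronize()
    if i % 10 == 0:
        print('iter {:2d}: +{:.1f} MiB over baseline'.format(i, used_mib() - base))
```

With the unpatched kernel one would expect the reported usage to creep upward across iterations; after the fix it should stay roughly flat, assuming the exercised op actually hits the sorting path.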
@@ -497,8 +497,8 @@ def test_gat_conv(g, idtype):
 def test_gat_conv_bi(g, idtype):
     g = g.astype(idtype).to(F.ctx())
     ctx = F.ctx()
-    gat = nn.GATConv((5, 10), 2, 4)
-    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 10)))
+    gat = nn.GATConv(5, 2, 4)
+    feat = (F.randn((g.number_of_src_nodes(), 5)), F.randn((g.number_of_dst_nodes(), 5)))
     gat = gat.to(ctx)
     h = gat(g, feat)
     assert h.shape == (g.number_of_dst_nodes(), 4, 2)
......