Unverified commit ee9887d6, authored by dddg617 and committed by GitHub

[Fix] Support HGT mini-batch (#4664)

* [Fix] Support HGT mini-batch

* [Test] Update HGT test

* [Conv] nightly update conv

* [Conv] nightly update
parent f4ee96db
@@ -142,16 +142,26 @@ class HGTConv(nn.Module):
             New node features. Shape: :math:`(|V|, D_{head} * N_{head})`.
         """
         self.presorted = presorted
+        if g.is_block:
+            x_src = x
+            x_dst = x[:g.num_dst_nodes()]
+            srcntype = ntype
+            dstntype = ntype[:g.num_dst_nodes()]
+        else:
+            x_src = x
+            x_dst = x
+            srcntype = ntype
+            dstntype = ntype
         with g.local_scope():
-            k = self.linear_k(x, ntype, presorted).view(
+            k = self.linear_k(x_src, srcntype, presorted).view(
                 -1, self.num_heads, self.head_size
             )
-            q = self.linear_q(x, ntype, presorted).view(
+            q = self.linear_q(x_dst, dstntype, presorted).view(
                 -1, self.num_heads, self.head_size
             )
-            v = self.linear_v(x, ntype, presorted).view(
+            v = self.linear_v(x_src, srcntype, presorted).view(
                 -1, self.num_heads, self.head_size
             )
             g.srcdata["k"] = k
             g.dstdata["q"] = q
             g.srcdata["v"] = v
@@ -163,12 +173,12 @@ class HGTConv(nn.Module):
             g.update_all(fn.copy_e("m", "m"), fn.sum("m", "h"))
             h = g.dstdata["h"].view(-1, self.num_heads * self.head_size)
             # target-specific aggregation
-            h = self.drop(self.linear_a(h, ntype, presorted))
-            alpha = torch.sigmoid(self.skip[ntype]).unsqueeze(-1)
-            if x.shape != h.shape:
-                h = h * alpha + (x @ self.residual_w) * (1 - alpha)
+            h = self.drop(self.linear_a(h, dstntype, presorted))
+            alpha = torch.sigmoid(self.skip[dstntype]).unsqueeze(-1)
+            if x_dst.shape != h.shape:
+                h = h * alpha + (x_dst @ self.residual_w) * (1 - alpha)
             else:
-                h = h * alpha + x * (1 - alpha)
+                h = h * alpha + x_dst * (1 - alpha)
             if self.use_norm:
                 h = self.norm(h)
             return h
@@ -1515,6 +1515,20 @@ def test_hgt(idtype, in_size, num_heads):
     sorted_x = sorted_g.ndata['x']
     sorted_y = m(sorted_g, sorted_x, sorted_ntype, sorted_etype, presorted=False)
     assert sorted_y.shape == (g.num_nodes(), head_size * num_heads)
+    # mini-batch
+    train_idx = th.randint(0, 100, (10,), dtype=idtype)
+    sampler = dgl.dataloading.NeighborSampler([-1])
+    train_loader = dgl.dataloading.DataLoader(g, train_idx.to(dev), sampler,
+                                              batch_size=8, device=dev,
+                                              shuffle=True)
+    input_nodes, output_nodes, blocks = next(iter(train_loader))
+    block = blocks[0]
+    x = x[input_nodes.to(th.long)]
+    ntype = ntype[input_nodes.to(th.long)]
+    edge = block.edata[dgl.EID]
+    etype = etype[edge.to(th.long)]
+    y = m(block, x, ntype, etype)
+    assert y.shape == (block.number_of_dst_nodes(), head_size * num_heads)
     # TODO(minjie): enable the following check
     #assert th.allclose(y, sorted_y[rev_idx], atol=1e-4, rtol=1e-4)
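
A self-contained sketch of the mini-batch path the new test exercises; the graph size, type counts, and hyperparameters here are illustrative assumptions, not values from the patch. Features and node types are gathered by input_nodes, and the block's edges are mapped back to parent-graph edge types through dgl.EID.

    import dgl
    import torch as th
    from dgl.nn import HGTConv

    num_nodes, in_size, head_size, num_heads = 100, 16, 8, 4
    g = dgl.rand_graph(num_nodes, 400)
    x = th.randn(num_nodes, in_size)
    ntype = th.randint(0, 2, (num_nodes,))      # 2 node types
    etype = th.randint(0, 3, (g.num_edges(),))  # 3 edge types

    conv = HGTConv(in_size, head_size, num_heads, num_ntypes=2, num_etypes=3)
    sampler = dgl.dataloading.NeighborSampler([-1])
    loader = dgl.dataloading.DataLoader(g, th.arange(10), sampler,
                                        batch_size=8, shuffle=True)

    input_nodes, output_nodes, blocks = next(iter(loader))
    block = blocks[0]
    y = conv(block,
             x[input_nodes.long()],               # source-node features
             ntype[input_nodes.long()],           # source-node types
             etype[block.edata[dgl.EID].long()])  # sampled edge -> parent type
    assert y.shape == (block.num_dst_nodes(), head_size * num_heads)
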