"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8e963d1c2a2b6ebf2880675516809bd9a39359a4"
Unverified Commit 895d6fff authored by Kay Liu's avatar Kay Liu Committed by GitHub
Browse files

[Doc] changing of output dimension (#2981)


Co-authored-by: default avatarzhjwy9343 <6593865@qq.com>
parent e2a28a6c
...@@ -30,7 +30,7 @@ implementation would be like: ...@@ -30,7 +30,7 @@ implementation would be like:
import torch import torch
import torch.nn as nn import torch.nn as nn
linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, 1))) linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))
def concat_message_function(edges): def concat_message_function(edges):
return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']])} return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']])}
g.apply_edges(concat_message_function) g.apply_edges(concat_message_function)
...@@ -48,8 +48,8 @@ respectively: ...@@ -48,8 +48,8 @@ respectively:
import dgl.function as fn import dgl.function as fn
linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, 1))) linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, 1))) linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
out_src = g.ndata['feat'] @ linear_src out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src}) g.srcdata.update({'out_src': out_src})
......
...@@ -20,7 +20,7 @@ DGL建议用户尽量减少边的特征维数。 ...@@ -20,7 +20,7 @@ DGL建议用户尽量减少边的特征维数。
import torch import torch
import torch.nn as nn import torch.nn as nn
linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, 1))) linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))
def concat_message_function(edges): def concat_message_function(edges):
return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])} return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])}
g.apply_edges(concat_message_function) g.apply_edges(concat_message_function)
...@@ -35,8 +35,8 @@ DGL建议用户尽量减少边的特征维数。 ...@@ -35,8 +35,8 @@ DGL建议用户尽量减少边的特征维数。
import dgl.function as fn import dgl.function as fn
linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, 1))) linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, 1))) linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
out_src = g.ndata['feat'] @ linear_src out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src}) g.srcdata.update({'out_src': out_src})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment