import torch
import torch.nn as nn


class EntropyLoss(nn.Module):
    """Entropy of the categorical distributions stored in ``s_l``.

    The last dimension of ``s_l`` is treated as the (unnormalised)
    probabilities of a ``torch.distributions.Categorical``.
    """

    def forward(self, adj, anext, s_l: torch.Tensor) -> torch.Tensor:
        """Return the entropy summed over dim -1, then averaged over dim -1.

        Args:
            adj: Unused; kept so every loss shares one call signature.
            anext: Unused; kept so every loss shares one call signature.
            s_l: Probabilities; one categorical distribution per vector
                along the last dimension.
                NOTE(review): presumably (batch, nodes, clusters), which
                yields a scalar result -- confirm against the caller.

        Returns:
            The reduced entropy tensor.

        Raises:
            AssertionError: If the reduced entropy is NaN.
        """
        per_row_entropy = torch.distributions.Categorical(probs=s_l).entropy()
        entropy = per_row_entropy.sum(-1).mean(-1)
        # Raise explicitly rather than via `assert`: assert statements are
        # removed under `python -O`, which would silently disable this guard.
        # AssertionError is kept so existing handlers still match.
        if torch.isnan(entropy):
            raise AssertionError("EntropyLoss produced NaN")
        return entropy


class LinkPredLoss(nn.Module):
    """Norm of ``adj - s_l @ s_l^T``, scaled by the size of ``adj``."""

    def forward(self, adj: torch.Tensor, anext, s_l: torch.Tensor) -> torch.Tensor:
        """Return the mean normalised reconstruction error.

        Args:
            adj: Target tensor; the norm is reduced over its dims 1 and 2,
                so ``adj - s_l @ s_l^T`` must have at least three dims.
                NOTE(review): presumably a (batch, nodes, nodes) adjacency
                tensor -- confirm against the caller.
            anext: Unused; kept so every loss shares one call signature.
            s_l: Tensor whose product with its own transpose (last two
                dims swapped) is compared against ``adj``.

        Returns:
            Mean of the per-sample norms, each divided by
            ``adj.size(1) * adj.size(2)``.
        """
        reconstruction = s_l.matmul(s_l.transpose(-1, -2))
        link_pred_loss = (adj - reconstruction).norm(dim=(1, 2))
        # Normalise by the number of entries reduced by the norm.
        link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
        return link_pred_loss.mean()