# loss.py
import torch
import torch.nn as nn

4

5
6
7
class EntropyLoss(nn.Module):
    """Entropy regulariser over the rows of an assignment tensor ``s_l``.

    Each slice along the last dimension of ``s_l`` is treated as the
    parameters of a categorical distribution. The loss is the sum of those
    distributions' entropies over dimension -2, averaged over every remaining
    leading dimension, so ``forward`` always returns a 0-dim tensor.
    """

    def forward(self, adj, anext, s_l):
        """Return the batch-averaged total row entropy of ``s_l`` (a scalar).

        Args:
            adj: Unused. Accepted so this module's ``forward`` has the same
                signature as ``LinkPredLoss.forward``.
            anext: Unused, for the same reason.
            s_l: Tensor whose last dimension holds non-negative category
                weights. ``torch.distributions.Categorical`` normalises each
                slice along that dimension to sum to one. Presumably a soft
                cluster-assignment matrix shaped (batch, nodes, clusters) —
                TODO confirm against callers.

        Returns:
            0-dim tensor. Per-row entropies are summed over dimension -2
            (one value per matrix) and then averaged over all leading
            dimensions. For a 2-D ``s_l``, this is simply the sum of row
            entropies.

        Raises:
            ValueError: If the resulting entropy is NaN.
        """
        row_entropy = torch.distributions.Categorical(probs=s_l).entropy()
        # Sum over the rows of each matrix, then average over *all* leading
        # dims. The previous ``.mean(-1)`` left a non-scalar result for
        # rank >= 4 inputs, which broke the NaN check below.
        entropy = row_entropy.sum(-1).mean()
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if torch.isnan(entropy):
            raise ValueError("EntropyLoss produced NaN")
        return entropy


class LinkPredLoss(nn.Module):
    """Link-prediction loss: how well ``s_l @ s_l^T`` reconstructs ``adj``."""

    def forward(self, adj, anext, s_l):
        """Return the mean normalised Frobenius error of ``adj - s_l @ s_l^T``.

        Args:
            adj: Tensor of shape (..., N, N); one matrix per leading index.
                Presumably a (batched) adjacency matrix — TODO confirm.
            anext: Unused. Accepted so this module's ``forward`` has the same
                signature as ``EntropyLoss.forward``.
            s_l: Tensor of shape (..., N, K). Its leading dims must broadcast
                against ``adj``'s.

        Returns:
            0-dim tensor. For each matrix, the Frobenius norm of
            ``adj - s_l @ s_l^T`` is divided by N * N (the number of entries
            in one ``adj`` matrix). These values are then averaged over all
            leading dimensions.
        """
        reconstruction = s_l.matmul(s_l.transpose(-1, -2))
        residual = adj - reconstruction
        # Frobenius norm over the last two dims. Negative indices (instead of
        # the former hard-coded (1, 2)) also accept unbatched (N, N) input and
        # extra leading dims; results are unchanged for 3-D input.
        link_pred_loss = residual.norm(dim=(-2, -1))
        # Normalise by the number of entries in a single adjacency matrix.
        link_pred_loss = link_pred_loss / (adj.size(-2) * adj.size(-1))
        return link_pred_loss.mean()