# loss.py
import torch
import torch.nn as nn

4

5
6
7
class EntropyLoss(nn.Module):
    """Entropy of the categorical distributions stored in ``s_l``.

    ``adj`` and ``anext`` are accepted so the call signature matches
    ``LinkPredLoss``; this loss does not read them.
    """

    def forward(self, adj, anext, s_l):
        """Return the entropy of ``s_l`` reduced to a single value.

        The last dimension of ``s_l`` is interpreted as the probabilities of
        a categorical distribution. The per-distribution entropies are summed
        over the last remaining dimension and then averaged over the one
        before it.

        Args:
            adj: Unused.
            anext: Unused.
            s_l: Probability tensor; presumably shaped
                (batch, nodes, clusters) -- TODO confirm against the caller.

        Returns:
            The reduced entropy (a scalar tensor for a 3-D ``s_l``).

        Raises:
            AssertionError: If the result is NaN (only when assertions are
                enabled, i.e. not under ``python -O``).
        """
        # One entropy value per distribution: shape is s_l.shape[:-1].
        per_distribution = torch.distributions.Categorical(probs=s_l).entropy()
        entropy = per_distribution.sum(-1).mean(-1)

        # Internal sanity check on the output, not input validation.
        assert not torch.isnan(entropy), "EntropyLoss produced NaN"
        return entropy


class LinkPredLoss(nn.Module):
    """Reconstruction error between ``adj`` and ``s_l @ s_l^T``.

    ``anext`` is accepted so the call signature matches ``EntropyLoss``;
    this loss does not read it.
    """

    def forward(self, adj, anext, s_l):
        """Return the mean normalised reconstruction error.

        For each entry along dimension 0, the norm of
        ``adj - s_l @ s_l^T`` is taken over dimensions 1 and 2 (the default
        norm of ``Tensor.norm``), divided by ``adj.size(1) * adj.size(2)``,
        and the results are averaged.

        Args:
            adj: Tensor with at least three dimensions; presumably a batch
                of adjacency matrices (batch, nodes, nodes) -- TODO confirm.
            anext: Unused.
            s_l: Tensor whose product with its own transpose (last two
                dimensions swapped) is subtracted from ``adj``; presumably
                (batch, nodes, clusters) -- TODO confirm.

        Returns:
            A scalar tensor: the mean of the normalised per-entry norms.
        """
        reconstruction = s_l.matmul(s_l.transpose(-1, -2))
        link_pred_loss = (adj - reconstruction).norm(dim=(1, 2))

        # Normalise by the number of entries in each (dim 1 x dim 2) matrix.
        link_pred_loss = link_pred_loss / (adj.size(1) * adj.size(2))
        return link_pred_loss.mean()