functions.py 646 Bytes
Newer Older
Zihao Ye's avatar
Zihao Ye committed
1
2
import torch as th

3

Zihao Ye's avatar
Zihao Ye committed
4
5
6
7
def src_dot_dst(src_field, dst_field, out_field):
    """Build an edge function scoring each edge by a source/destination dot product.

    Serves as a surrogate for the ``src_dot_dst`` built-in apply_edge function.

    Args:
        src_field: Key of the source-node feature, read from ``edges.src``.
        dst_field: Key of the destination-node feature, read from ``edges.dst``.
        out_field: Key under which the resulting score is returned.

    Returns:
        A callable ``func(edges)`` returning ``{out_field: score}``, where
        ``score`` is the elementwise product of the two features summed over
        the last dimension, with that dimension kept (size 1).
    """

    def func(edges):
        lhs = edges.src[src_field]
        rhs = edges.dst[dst_field]
        # Reduce only the last axis; keepdim preserves it as size 1.
        score = (lhs * rhs).sum(-1, keepdim=True)
        return {out_field: score}

    return func

18

Zihao Ye's avatar
Zihao Ye committed
19
20
21
22
def scaled_exp(field, c, bound=10):
    """Build an edge function computing ``exp(clamp(x / c, -bound, bound))``.

    This is the exponentiation step required by *Scaled Dot-Product Attention*
    mentioned in the paper. The raw score ``x`` is read from
    ``edges.data[field]``, divided by the scaling constant ``c``, and clamped
    to ``[-bound, bound]`` before ``exp``. The result is therefore bounded by
    ``exp(bound)`` and stays finite. The result is returned under the same
    key, ``field``.

    Args:
        field: Key of the edge feature holding the raw scores; also the key of
            the returned result.
        c: Scaling constant that the scores are divided by.
        bound: Symmetric clamp limit applied to ``x / c``. The default of 10
            matches the previously hard-coded behavior. Pass ``None`` to
            disable clamping.

    Returns:
        A callable ``func(edges)`` returning ``{field: exp(...)}``.

    Raises:
        ValueError: If ``bound`` is negative. With ``min > max``, ``clamp``
            would silently produce a constant.
    """
    if bound is not None and bound < 0:
        raise ValueError(f"bound must be non-negative or None, got {bound!r}")

    def func(edges):
        scaled = edges.data[field] / c
        if bound is not None:
            # Clamp before exp so large scores cannot overflow to inf.
            scaled = scaled.clamp(-bound, bound)
        return {field: th.exp(scaled)}

    return func