import numpy as np
import torch as th
import torch.nn as nn

from .layers import clones


class MultiHeadAttention(nn.Module):
    "Multi-Head Attention"

    def __init__(self, h, dim_model):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // h
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model, bias=False), 4)

    def get(self, x, fields="qkv"):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
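        # Each projection maps (batch, dim_model) -> (batch, dim_model);
        # the result is then split into h heads of width d_k.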
        if "q" in fields:
            ret["q"] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if "k" in fields:
            ret["k"] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if "v" in fields:
            ret["v"] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
        "get output of the multi-head attention"
        batch_size = x.shape[0]
        return self.linears[3](x.view(batch_size, -1))
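

# Usage sketch (illustrative; assumes `clones(module, k)` from .layers
# returns an nn.ModuleList of k deep copies, as in the standard
# Annotated-Transformer helper):
#
#   attn = MultiHeadAttention(h=8, dim_model=512)
#   x = th.rand(32, 512)              # 32 nodes, hidden size 512
#   qkv = attn.get(x, fields="qkv")   # each of q/k/v: (32, 8, 64)
#   out = attn.get_o(qkv["v"])        # concat heads + W_o: (32, 512)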