# positions.py
"""Sinusoidal position encoders: a standard additive table encoder and a
duration-based variant that restarts the position index at each token."""

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np


class SinusoidalPositionEncoder(nn.Module):
    """Adds fixed sinusoidal position encodings to a (batch, time, depth) input.

    The encoding table is precomputed up to ``max_len`` and regrown on the
    fly whenever a longer sequence is seen.
    """

    def __init__(self, max_len, depth):
        super(SinusoidalPositionEncoder, self).__init__()

        self.max_len = max_len
        self.depth = depth
        # Frozen parameter so the table moves with the module across devices.
        self.position_enc = nn.Parameter(
            self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0),
            requires_grad=False,
        )

    def forward(self, input):
        bz_in, len_in, _ = input.size()
        if len_in > self.max_len:
            # Input is longer than the cached table: rebuild it at the new length.
            self.max_len = len_in
            self.position_enc.data = (
                self.get_sinusoid_encoding_table(self.max_len, self.depth)
                .unsqueeze(0)
                .to(input.device)
            )

        output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1)

        return output

    @staticmethod
    def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
        """Sinusoid position encoding table.

        Positions are 1-based; ``d_hid`` is assumed even. The layout is
        non-interleaved: the first ``d_hid // 2`` channels hold the sines
        and the remaining channels the cosines of the same scaled times.
        """

        def cal_angle(position, hid_idx):
            return position / np.power(10000, hid_idx / float(d_hid / 2 - 1))

        def get_posi_angle_vec(position):
            return [cal_angle(position, hid_j) for hid_j in range(d_hid // 2)]

        scaled_time_table = np.array(
            [get_posi_angle_vec(pos_i + 1) for pos_i in range(n_position)]
        )

        sinusoid_table = np.zeros((n_position, d_hid))
        sinusoid_table[:, : d_hid // 2] = np.sin(scaled_time_table)
        sinusoid_table[:, d_hid // 2 :] = np.cos(scaled_time_table)

        if padding_idx is not None:
            # Zero vector for the padding position.
            sinusoid_table[padding_idx] = 0.0

        return torch.FloatTensor(sinusoid_table)
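
# Usage sketch for SinusoidalPositionEncoder (illustrative, not part of the
# original module; the shapes and hyperparameters below are assumptions):
#
#   enc = SinusoidalPositionEncoder(max_len=1000, depth=256)
#   out = enc(torch.randn(4, 37, 256))  # out.shape == (4, 37, 256)
#
# Feeding a sequence longer than max_len is also valid; the table is rebuilt
# on the fly inside forward() as shown above.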


class DurSinusoidalPositionEncoder(nn.Module):
    """Sinusoidal encodings indexed by position *within* each token.

    Given per-token durations (in frames), each output frame is assigned a
    1-based position that restarts at every token boundary, and that index
    is embedded with the usual sin/cos scheme.
    """

    def __init__(self, depth, outputs_per_step):
        super(DurSinusoidalPositionEncoder, self).__init__()

        self.depth = depth
        self.outputs_per_step = outputs_per_step

        # Divisors 10000^(2 * (i // 2) / depth); consecutive channel pairs
        # share a timescale, matching the interleaved sin/cos in forward().
        inv_timescales = [
            np.power(10000, 2 * (hid_idx // 2) / depth) for hid_idx in range(depth)
        ]
        self.inv_timescales = nn.Parameter(
            torch.FloatTensor(inv_timescales), requires_grad=False
        )

    def forward(self, durations, masks=None):
        # Round non-negative durations to integer frame counts per token.
        reps = (durations + 0.5).long()
        output_lens = reps.sum(dim=1)
        max_len = output_lens.max()
        # Cumulative frame boundaries, left-padded with 0: shape (B, 1, T + 1).
        reps_cumsum = torch.cumsum(F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[
            :, None, :
        ]
        range_ = torch.arange(max_len).to(durations.device)[None, :, None]
        # One-hot frame-to-token alignment: frame i belongs to token j iff
        # cumsum[j] <= i < cumsum[j + 1].
        mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
        mult = mult.float()
        # Start frame of the token each frame belongs to.
        offsets = torch.matmul(mult, reps_cumsum[:, 0, :-1].unsqueeze(-1)).squeeze(-1)
        # 1-based position of each frame within its token.
        dur_pos = range_[:, :, 0] - offsets + 1

        if masks is not None:
            assert masks.size(1) == dur_pos.size(1)
            # Zero the position index at padded frames.
            dur_pos = dur_pos.masked_fill(masks, 0.0)

        # Pad the frame axis up to a multiple of outputs_per_step.
        seq_len = dur_pos.size(1)
        padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step
        if padding < self.outputs_per_step:
            dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0)

        # Interleaved encoding: sin on even channels, cos on odd channels.
        position_embedding = dur_pos[:, :, None] / self.inv_timescales[None, None, :]
        position_embedding[:, :, 0::2] = torch.sin(position_embedding[:, :, 0::2])
        position_embedding[:, :, 1::2] = torch.cos(position_embedding[:, :, 1::2])

        return position_embedding
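

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original module; the
    # batch size, depth, and durations below are arbitrary assumptions).
    torch.manual_seed(0)

    enc = SinusoidalPositionEncoder(max_len=100, depth=8)
    x = torch.randn(2, 5, 8)  # (batch, time, depth)
    assert enc(x).shape == (2, 5, 8)

    dur_enc = DurSinusoidalPositionEncoder(depth=8, outputs_per_step=2)
    durations = torch.tensor([[2.0, 3.0, 1.0], [1.0, 1.0, 1.0]])
    pe = dur_enc(durations)
    # The first item expands to 2 + 3 + 1 = 6 frames; 6 is already a multiple
    # of outputs_per_step, so no extra padding is added.
    assert pe.shape == (2, 6, 8)
    print("ok")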