mask.py 382 Bytes
Newer Older
yuguo960516's avatar
bloom  
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np


def make_padding_mask(q_ids, kv_ids, pad_id):
    """Build a 2-D float mask of query/key pairs where neither side is padding.

    Args:
        q_ids: Query ids; any array-like, flattened to one dimension.
        kv_ids: Key/value ids; any array-like, flattened to one dimension.
        pad_id: The id value that marks a padding position.

    Returns:
        A float ``numpy`` array of shape ``(len(q), len(kv))`` (lengths after
        flattening) holding 1.0 at ``[i, j]`` when both ``q[i]`` and
        ``kv[j]`` differ from ``pad_id``, and 0.0 otherwise.
    """
    query_valid = (np.array(q_ids) != pad_id).ravel()
    key_valid = (np.array(kv_ids) != pad_id).ravel()
    # Outer logical AND: entry [i, j] is True only when both positions are non-pad.
    return np.logical_and.outer(query_valid, key_valid).astype(float)


def make_sequence_mask(ids):
    """Build a square lower-triangular mask sized to ``ids``.

    Only ``len(ids)`` is used; the id values themselves are ignored.

    Args:
        ids: A sized sequence of ids.

    Returns:
        A float ``numpy`` array of shape ``(n, n)`` with ``n = len(ids)``,
        holding 1.0 at ``[i, j]`` when ``j <= i`` and 0.0 above the diagonal.
    """
    # np.tri places ones on and below the main diagonal, which equals the
    # transpose of an upper-triangular matrix of ones.
    return np.tri(len(ids))