torch_sdpa.py 530 Bytes
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn.functional as F


def torch_sdpa(
    q,
    k,
    v,
    drop_rate=0,
    attn_mask=None,
    causal=False,
):
    """Run PyTorch scaled-dot-product attention and merge the last two axes.

    Axes 1 and 2 of ``q``, ``k`` and ``v`` are swapped before the call to
    ``F.scaled_dot_product_attention`` and swapped back afterwards; the
    4-D result is then reshaped so its last two axes are flattened into one.

    Args:
        q: Query tensor. Presumably laid out as
            ``(batch, seq_len, heads, head_dim)`` so that the transpose
            yields the heads-first layout the PyTorch kernel works on --
            confirm against callers.
        k: Key tensor, same layout convention as ``q``.
        v: Value tensor, same layout convention as ``q``.
        drop_rate: Forwarded unchanged as ``dropout_p``.
        attn_mask: Optional attention mask. A boolean mask is forwarded
            as-is; any other dtype is cast to ``q.dtype`` first.
        causal: Forwarded unchanged as ``is_causal``.

    Returns:
        Tensor of shape ``(x.shape[0], x.shape[1], -1)`` where ``x`` is the
        attention output after swapping axes 1 and 2 back, i.e. the last
        two axes are merged.

    Raises:
        ValueError: If the attention output is not 4-D (the shape unpack
            below requires exactly four dimensions).
    """
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    # Non-boolean masks are aligned with the query dtype before the call;
    # boolean masks are passed through untouched.
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask.to(q.dtype)

    x = F.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=attn_mask,
        dropout_p=drop_rate,
        is_causal=causal,
    )

    # Undo the axis swap, then collapse the trailing two axes into one.
    x = x.transpose(1, 2)
    # The 4-way unpack is kept on purpose: it rejects non-4-D results.
    batch, seq_len, _heads, _head_dim = x.shape
    return x.reshape(batch, seq_len, -1)