# Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["XLMRoberta", "xlm_roberta_large"]


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x:   [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x


class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x


class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self,
                 vocab_size=250002,
                 max_seq_len=514,
                 type_size=1,
                 pad_id=1,
                 dim=1024,
                 num_heads=16,
                 num_layers=24,
                 post_norm=True,
                 dropout=0.1,
                 eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        # position ids follow the XLM-RoBERTa convention: padding tokens keep
        # pad_id, real tokens count upward from pad_id + 1
        x = (self.token_embedding(ids) +
             self.type_embedding(torch.zeros_like(ids)) +
             self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask))
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        # expand the 0/1 padding mask into an additive attention mask
        mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x


def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
    """
    XLMRobertaLarge adapted from Huggingface.
    """
    # params
    cfg = dict(
        vocab_size=250002,
        max_seq_len=514,
        type_size=1,
        pad_id=1,
        dim=1024,
        num_heads=16,
        num_layers=24,
        post_norm=True,
        dropout=0.1,
        eps=1e-5)
    cfg.update(**kwargs)

    # init a model on device
    with torch.device(device):
        model = XLMRoberta(**cfg)
    return model
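

# A minimal usage sketch: build the large configuration on CPU and run a
# forward pass on dummy token ids. The random ids below are illustrative only;
# real inputs would come from the matching XLM-RoBERTa tokenizer (pad_id=1).
if __name__ == "__main__":
    model = xlm_roberta_large().eval()
    ids = torch.randint(2, 250002, (1, 16), dtype=torch.long)  # [B, L] dummy ids
    with torch.no_grad():
        out = model(ids)
    print(out.shape)  # torch.Size([1, 16, 1024])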