# head.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from loss_func.cross_entropy import vocab_cross_entropy

6
7
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
8
from colossalai.nn.layer.layernorm import MixedFusedLayerNorm as LayerNorm
9
10
11
12

from .linear import Linear
from .pooler import Pooler

13
14
15
16
17
18
19
20
21

class BertLMHead(nn.Module):
    """Masked-LM head for BERT that returns the language-modelling loss.

    The head transforms the hidden states (dense -> GELU -> layer norm),
    projects them with an externally supplied weight matrix plus a bias owned
    by this module, and passes the resulting logits to ``vocab_cross_entropy``.

    Arguments:
        vocab_size: length of the output bias vector added to the logits
        hidden_size: input and output width of the dense layer, and the size
            normalised by the layer norm
    """

    def __init__(
        self,
        vocab_size,
        hidden_size,
    ):
        super(BertLMHead, self).__init__()
        # Bias added to the projected logits; starts at zero.
        self.bias = torch.nn.Parameter(torch.zeros(vocab_size))

        self.dense = Linear(hidden_size, hidden_size)
        self.layernorm = LayerNorm(hidden_size)
        self.gelu = torch.nn.functional.gelu

    def forward(self, hidden_states, word_embeddings_weight, lm_labels):
        """Compute the masked-LM loss for ``hidden_states``.

        Arguments:
            hidden_states: tensor fed through the dense layer, GELU and
                layer norm
            word_embeddings_weight: weight matrix handed to ``F.linear`` as
                the output projection (presumably the shared word-embedding
                matrix with ``vocab_size`` rows -- confirm against the caller)
            lm_labels: targets forwarded unchanged to ``vocab_cross_entropy``

        Returns:
            Whatever ``vocab_cross_entropy(logits, lm_labels)`` returns.
        """
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)

        # Project with the caller-supplied weight rather than a weight owned
        # by this module; only the bias is a parameter of the head.
        output = F.linear(hidden_states, word_embeddings_weight, self.bias)
        lm_loss = vocab_cross_entropy(output, lm_labels)

        return lm_loss


class BertBinaryHead(nn.Module):
    """Two-way classification head: a ``Pooler`` followed by a linear layer.

    Only the process whose local rank in ``ParallelMode.SEQUENCE`` is 0
    computes an output; every other rank returns ``None``.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.pooler = Pooler(hidden_size)
        # Two output units.
        self.dense = Linear(hidden_size, 2)

    def forward(self, hidden_states):
        """Return the 2-unit output on sequence-parallel rank 0, else ``None``."""
        if gpc.get_local_rank(ParallelMode.SEQUENCE) != 0:
            return None

        pooled = self.pooler(hidden_states)
        return self.dense(pooled)


class BertDualHead(nn.Module):
    """Combine the masked-LM head with an optional binary head.

    Arguments:
        hidden_size: hidden size passed to both heads
        vocab_size: vocabulary size passed to the LM head
        add_binary_head: when truthy, a ``BertBinaryHead`` is built and used;
            otherwise ``binary_head`` is ``None``
    """

    def __init__(self, hidden_size, vocab_size, add_binary_head):
        super().__init__()
        self.lm_head = BertLMHead(vocab_size, hidden_size)
        self.add_binary_head = add_binary_head
        self.binary_head = BertBinaryHead(hidden_size) if add_binary_head else None

    def forward(self, hidden_states, word_embeddings_weight, lm_labels):
        """Run both heads and return ``(lm_loss, binary_output)``.

        ``binary_output`` is ``None`` when the binary head is disabled.
        """
        # The binary head is evaluated before the LM head.
        binary_output = (
            self.binary_head(hidden_states) if self.add_binary_head else None
        )
        lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels)
        return lm_loss, binary_output