lm_head_layer.py 689 Bytes
Newer Older
yuguo960516's avatar
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from oneflow import nn

from libai.layers import Linear, LMLogits


class LMHead(nn.Module):
    """Language-model output head that maps decoder states to vocabulary logits.

    For ``model_type == "mt5"`` a dedicated ``Linear`` projection (no bias) is
    built, placed at pipeline stage ``2 * hidden_layers - 1``.  Every other
    model type uses ``LMLogits``, which computes logits against an externally
    supplied embedding weight (weight tying).
    """

    def __init__(self, model_type, hidden_size, vocab_size, hidden_layers):
        super().__init__()
        if model_type == "mt5":
            # mt5 uses its own untied projection on the final decoder stage.
            self.lm_head = Linear(
                hidden_size, vocab_size, bias=False, layer_idx=2 * hidden_layers - 1
            )
        else:
            self.lm_head = LMLogits(vocab_size, bias=True)

    def forward(self, decoder_states, embed_weight=None):
        # A Linear head ignores the tied embedding weight; LMLogits requires it.
        if isinstance(self.lm_head, Linear):
            return self.lm_head(decoder_states)
        return self.lm_head(decoder_states, embed_weight)