import oneflow as flow
from oneflow import nn

from libai.layers.cross_entropy import ParallelCrossEntropyLoss
from libai.utils import distributed as dist

from .transformer_model import TransformerModel


class Seq2SeqLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.lm_loss = ParallelCrossEntropyLoss()

    def forward(self, logits, lm_labels):
        # Shift by one position so that the logits at step t are scored against
        # the label token at step t + 1 (standard next-token prediction).
        logits = logits[:, :-1, :]
        lm_labels = lm_labels[:, 1:]
        lm_loss = self.lm_loss(logits, lm_labels)
        lm_loss = lm_loss.mean()
        return lm_loss
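

# A minimal sketch (not part of the original file) of the shift-by-one alignment used
# by Seq2SeqLoss above: the prediction at step t is scored against the token at step
# t + 1. The plain flow.nn.CrossEntropyLoss stands in for ParallelCrossEntropyLoss,
# which expects global (distributed) tensors in a real LiBai run; shapes are made up.
def _shifted_loss_sketch():
    batch_size, seq_length, vocab_size = 2, 5, 11
    logits = flow.randn(batch_size, seq_length, vocab_size)
    labels = flow.randint(0, vocab_size, (batch_size, seq_length))
    shifted_logits = logits[:, :-1, :].reshape(-1, vocab_size)  # predictions for steps 0..T-2
    shifted_labels = labels[:, 1:].reshape(-1)  # targets are the next tokens, steps 1..T-1
    return flow.nn.CrossEntropyLoss()(shifted_logits, shifted_labels)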


class Seq2Seq(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.language_model = TransformerModel(cfg)
        self.loss_func = Seq2SeqLoss()

    def forward(
        self,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attn_mask,
        decoder_attn_mask,
        encoder_decoder_attn_mask,
    ):
        logits = self.language_model(
            encoder_input_ids,
            decoder_input_ids,
            encoder_attn_mask,
            decoder_attn_mask,
            encoder_decoder_attn_mask,
        )

        if self.training:
            # The decoder input ids double as labels; Seq2SeqLoss shifts them by one position.
            loss = self.loss_func(logits, decoder_input_ids)
            return {"total_loss": loss}

        # Flatten to (batch_size * seq_length, vocab_size) for evaluation.
        logits = logits.view(-1, logits.shape[-1])
        return {"prediction_scores": logits}

    def encode(
        self,
        encoder_input_ids,
        encoder_attn_mask,
    ):
        encoder_input_embeddings = self.language_model.embedding(encoder_input_ids)
        encoder_extended_attn_mask = (
            self.language_model.extended_attn_mask(encoder_attn_mask)
            if encoder_attn_mask is not None
            else None
        )
        encoder_states = self.language_model.encoder(
            encoder_input_embeddings,
            encoder_extended_attn_mask,
        )
        return encoder_states

    def decode(
        self,
        decoder_input_ids,
        decoder_attn_mask,
        encoder_states,
        encoder_decoder_attn_mask,
    ):
        decoder_input_embeddings = self.language_model.embedding(decoder_input_ids)
        decoder_extended_attn_mask = self.language_model.extended_attn_mask(decoder_attn_mask)
        encoder_decoder_extended_attn_mask = (
            self.language_model.extended_attn_mask(encoder_decoder_attn_mask)
            if encoder_decoder_attn_mask is not None
            else None
        )
        decoder_states = self.language_model.decoder(
            decoder_input_embeddings,
            decoder_extended_attn_mask,
            encoder_states,
            encoder_decoder_extended_attn_mask,
        )
        logits = self.language_model.lm_head(decoder_states)
        return logits

    @staticmethod
    def set_pipeline_stage_id(model):
        dist_utils = dist.get_dist_util()

        from .transformer_model import ExtendedMask, TransformerEmbedding, TransformerLayer

        # Assign a pipeline-parallel stage id (and placement) to every relevant module.
        if hasattr(model.language_model.lm_head, "config"):
            # Old API in OneFlow 0.8: module_block.origin returns the original nn.Module
            # and module_block.config carries the graph-mode configuration.
            for module_block in model.modules():
                if isinstance(module_block.origin, (TransformerEmbedding, ExtendedMask)):
                    # Embedding and mask construction run on the first stage.
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.origin, TransformerLayer):
                    # Each transformer layer is placed on the stage of its layer index.
                    module_block.config.set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )

            # The lm_head runs on the last stage.
            model.language_model.lm_head.config.set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )
        else:
            # Newer API: the original module is retrieved with .to(nn.Module) and the
            # graph-mode configuration with .to(flow.nn.graph.GraphModule).
            for module_block in model.modules():
                if isinstance(module_block.to(nn.Module), (TransformerEmbedding, ExtendedMask)):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(0), dist.get_layer_placement(0)
                    )
                elif isinstance(module_block.to(nn.Module), TransformerLayer):
                    module_block.to(flow.nn.graph.GraphModule).set_stage(
                        dist_utils.get_layer_stage_id(module_block.layer_idx),
                        dist.get_layer_placement(module_block.layer_idx),
                    )

            # The lm_head runs on the last stage.
            model.language_model.lm_head.to(flow.nn.graph.GraphModule).set_stage(
                dist_utils.get_layer_stage_id(-1), dist.get_layer_placement(-1)
            )