"""Utility functions shared by prompting and finetuning."""

import torch
from megatron import mpu
from megatron import get_args
from megatron import get_tokenizer
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module


def switch(val1, val2, boolean):
    """Select elementwise between two tensors using a 0/1 mask.

    ``boolean`` is cast to ``val1``'s dtype and used as a blend weight:
    positions where it is 1 take ``val2``, positions where it is 0 keep
    ``val1``. The result is only a clean selection for 0/1 (or bool) masks.
    """
    weight = boolean.type_as(val1)
    keep_weight = 1 - weight
    return keep_weight * val1 + weight * val2


def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
                 layer_past=None, get_key_value=None,
                 forward_method_parallel_output=None):
    """Run this pipeline stage's forward pass and return its output.

    The input activation is received from the previous stage with
    ``recv_forward``, the model is run on the given token inputs, and the
    result is passed on with ``send_forward``.

    ``args.seq_length`` is temporarily set to ``tokens.shape[1]`` for the
    duration of the call and restored afterwards (also on error).
    NOTE(review): presumably the p2p communication helpers read
    ``args.seq_length`` to size their buffers; confirm in
    ``megatron.p2p_communication``.

    Args:
        model: the (possibly torchDDP/LocalDDP/Float16Module-wrapped) model.
        tokens: token ids; ``tokens.shape[1]`` is used as the sequence length.
        position_ids: forwarded to ``model``.
        attention_mask: forwarded to ``model``.
        tokentype_ids: forwarded to ``model`` as ``tokentype_ids``.
        layer_past: not passed to the model; when ``get_key_value`` is set it
            is replaced by the value unpacked from the model output.
        get_key_value: if truthy, the model output is unpacked as
            ``(output, layer_past)`` and both are returned. NOTE(review): the
            model is not told to return key/values here, so this only works
            if it does so by default.
        forward_method_parallel_output: accepted for interface compatibility
            but NOT forwarded to the model. NOTE(review): callers passing
            ``False`` get the model's default output mode.

    Returns:
        The model output, or ``(output, layer_past)`` when ``get_key_value``.
    """
    args = get_args()
    orig_seq_length = args.seq_length
    # Tell the communication functions the correct (current) size.
    args.seq_length = tokens.shape[1]
    try:
        input_tensor = recv_forward()

        # Forward pass through the model.
        unwrapped_model = unwrap_model(
            model, (torchDDP, LocalDDP, Float16Module))
        unwrapped_model.set_input_tensor(input_tensor)
        output_tensor = model(tokens, position_ids, attention_mask,
                              tokentype_ids=tokentype_ids)

        if get_key_value:
            output_tensor, layer_past = output_tensor

        # Must run before seq_length is restored, like recv_forward above.
        send_forward(output_tensor)
    finally:
        # Restore the shared global config even if any step above raised,
        # so later callers do not see a stale sequence length.
        args.seq_length = orig_seq_length

    if get_key_value:
        return output_tensor, layer_past
    return output_tensor
    

def pad_batch(batch, pad_id, args):
    """Right-pad every token list in ``batch`` up to ``args.seq_length``.

    The lists are extended in place with ``pad_id``. Lists that are already
    ``args.seq_length`` long (or longer) are left untouched; nothing is
    truncated.

    Args:
        batch: list of token-id lists.
        pad_id: token id used for padding.
        args: object exposing ``seq_length``.

    Returns:
        ``(batch, context_lengths)``: the same (now padded) list object and
        the pre-padding length of each entry.
    """
    target_length = args.seq_length
    original_lengths = []
    for sequence in batch:
        unpadded = len(sequence)
        missing = target_length - unpadded
        if missing > 0:
            sequence.extend([pad_id] * missing)
        # Remember where the real context ends.
        original_lengths.append(unpadded)
    return batch, original_lengths


def get_batch(context_tokens):
    """Build model inputs from a tensor of context tokens.

    ``context_tokens`` is reshaped to ``(args.micro_batch_size, -1)`` and
    moved to the current CUDA device. The attention mask and position ids
    come from ``get_ltor_masks_and_position_ids``, which is called with the
    tokenizer's EOD id and the reset/mask flags from ``args``.

    Returns:
        ``(tokens, attention_mask, position_ids)``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    reshaped = context_tokens.view(args.micro_batch_size, -1)
    tokens = reshaped.contiguous().cuda()

    # The middle return value (presumably the loss mask) is not needed here.
    attention_mask, _unused, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss,
    )
    return tokens, attention_mask, position_ids


def sample_sequence_batch(model, context_tokens, context_lengths,
                          attention_mask, position_ids,
                          maxlen=None, type_ids=None):
    """Greedily generate tokens for a batch, yielding after every step.

    At each step the whole sequence is run through ``model`` (via
    ``forward_step``). On the last pipeline stage, the argmax of the logits
    at position ``context_length - 1`` becomes the token at
    ``context_length``. Rows whose prompt extends past that position keep
    their own prompt token instead (selected with ``switch``).

    Pipeline stages cooperate through collectives:
    * the last stage broadcasts the new tokens over the embedding group, and
      the first stage receives them into its own ``tokens``;
    * the last stage broadcasts a "done" flag over the pipeline group, and
      every other stage receives it. Generation stops once it is set.

    The model is put in eval mode and the loop runs under ``torch.no_grad``.

    Args:
        model: the model, used through ``forward_step``.
        context_tokens: 2-D token tensor (batch, seq), padded. It is
            updated in place as tokens are generated.
        context_lengths: tensor of per-row prompt lengths.
        attention_mask: forwarded to ``forward_step``.
        position_ids: forwarded to ``forward_step``.
        maxlen: last position to fill. Defaults to
            ``min(args.seq_length - 1, min_prompt_len + args.out_seq_length)``.
        type_ids: forwarded to the model as ``tokentype_ids``.

    Yields:
        On the last pipeline stage, ``(tokens, lengths)``. ``lengths[i]`` is
        the position at which row ``i`` first produced ``eos_id``, or
        ``maxlen`` if it has not yet. On a first (non-last) stage,
        ``(tokens, None)``. On any other stage, ``(None, None)``.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # eos_id may be set on args to support generate_samples_eval, which
        # passes eos_id as an argument and needs termination when that id is
        # found; otherwise fall back to the tokenizer's EOD token.
        if hasattr(args, 'eos_id'):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        org_context_length = context_length

        # prepare batch size, context tokens, maximum length
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen

        # start the generation process
        while context_length <= maxlen:
            # forward and obtain the logits
            output = forward_step(model, tokens,
                                  position_ids,
                                  attention_mask,
                                  tokentype_ids=type_ids,
                                  forward_method_parallel_output=False)
            if mpu.is_pipeline_last_stage():
                assert output is not None
                # Logits at the last filled position predict the next token.
                logits = output[:, context_length - 1, :]

            # generate tokens iteratively
            if mpu.is_pipeline_last_stage():
                prev = torch.argmax(logits, dim=-1).view(-1)

                # start to add new tokens when the generated length
                # exceeds the context length; rows with a longer prompt keep
                # their own token at this position.
                started = context_lengths <= context_length
                new_tokens = switch(
                    tokens[:, context_length].view(-1), prev, started)
                tokens[:, context_length] = new_tokens
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_embedding_group()
                torch.distributed.broadcast(new_tokens, src, group)

                # check whether the generation is finished
                done_token = (prev == eos_id).byte() & started.byte()
                just_finished = (done_token & ~is_done).bool()
                lengths[just_finished.view(-1)] = context_length
                is_done = is_done | done_token

                done = torch.all(is_done)
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)
                yield tokens, lengths

            else:
                if mpu.is_pipeline_first_stage():
                    # Receive the tokens generated by the last stage so this
                    # stage's copy of `tokens` stays in sync.
                    src = mpu.get_pipeline_model_parallel_last_rank()
                    group = mpu.get_embedding_group()
                    new_tokens = torch.empty_like(tokens[:, context_length])
                    torch.distributed.broadcast(new_tokens, src, group)
                    tokens[:, context_length] = new_tokens
                    yield tokens, None
                else:
                    yield None, None

                # Receive the "done" flag from the last stage.
                done = torch.cuda.ByteTensor([0])
                src = mpu.get_pipeline_model_parallel_last_rank()
                group = mpu.get_pipeline_model_parallel_group()
                torch.distributed.broadcast(done, src, group)

            context_length += 1
            if done:
                break


def get_token_stream(model, context_tokens):
    """Generate output tokens for a batch of prompts, one step at a time.

    The prompts are padded to ``args.seq_length`` with the tokenizer's EOD
    id and moved to CUDA. They are then broadcast from the tensor-parallel
    source rank so that all ranks in the group use the same inputs.
    Generation itself is delegated to ``sample_sequence_batch``.

    Args:
        model: the model used for generation.
        context_tokens: list of token-id lists, one per prompt. NOTE: these
            lists are padded in place by ``pad_batch``.

    Yields:
        ``(tokens, lengths)`` per generation step. ``tokens`` is truncated to
        the positions filled so far, and ``lengths`` is as yielded by
        ``sample_sequence_batch`` (``None`` on non-last pipeline stages).
        ``(None, None)`` is yielded on stages that hold no tokens.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    # padding for context tokens
    context_tokens, context_lengths = pad_batch(context_tokens,
                                                tokenizer.eod, args)

    # move tokens to CUDA and share the source rank's copy with the
    # tensor-model-parallel group
    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    tp_src = mpu.get_tensor_model_parallel_src_rank()
    tp_group = mpu.get_tensor_model_parallel_group()
    torch.distributed.broadcast(context_length_tensor, tp_src, group=tp_group)
    torch.distributed.broadcast(context_tokens_tensor, tp_src, group=tp_group)

    # prepare batch; get_batch's reshaped token tensor is not needed because
    # generation operates on context_tokens_tensor directly
    context_length = context_length_tensor.min().item()
    _, attention_mask, position_ids = get_batch(context_tokens_tensor)

    # get generation outputs
    batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                 context_length_tensor,
                                                 attention_mask, position_ids)
    for tokens, lengths in batch_token_iterator:
        # One more position is filled after every generation step.
        context_length += 1
        if tokens is not None:
            yield tokens[:, :context_length], lengths
        else:
            yield None, None