# Copyright (c) 2023, Tri Dao.

import os
import time
from pathlib import Path

current_dir = Path(__file__).parent.absolute()

import pytest
import torch
from einops import rearrange
from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
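    # The remapped HF checkpoint must expose exactly the same parameter names and shapes
    # as our GPT-style model.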
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.transformer(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
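    """Check that the tensor-parallel forward pass matches the HF implementation:
    our fp16 outputs should be around as close to the HF fp32 outputs as HF's own
    fp16 outputs are.
    """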
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
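    # Each rank only holds a shard of the outputs: hidden states are gathered back along the
    # flattened token dimension, logits across ranks and reassembled along the vocab dimension.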
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model
    parallel_state.destroy_model_parallel()
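    # Only rank 0 loads the full HF models and checks the tensor-parallel outputs against them.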

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
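        # Teacher-force the fp32 model on the HF-generated sequence and keep only the
        # positions that predict the generated tokens.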
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
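    # teacher_outputs makes our generation emit the same tokens as the HF run above, so the
    # per-step scores line up token-for-token.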
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
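    # update_graph_cache pre-captures the decoding CUDA graphs for this batch size and length.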
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
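    # Stack the per-step scores into (batch, num_generated_tokens, vocab) tensors.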
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
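    # Pad the vocabulary so the embedding and lm_head shards split evenly across ranks.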
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()
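    # Only rank 0 runs the HF reference models and compares the generation scores.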

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref
        logits_hf = torch.stack(out_hf.scores, dim=1)

        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
        assert torch.equal(logits_cg, logits)