# Copyright (c) 2023, Tri Dao.

import os
import time
from pathlib import Path

current_dir = Path(__file__).parent.absolute()

import pytest
import torch
from einops import rearrange
from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
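    # remap_state_dict_hf_falcon renames the tensors in the HF Falcon checkpoint to the
    # flash_attn GPTLMHeadModel parameter layout, so keys and shapes can be compared directly.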
    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.transformer(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
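    # shard_state_dict_tp slices the full checkpoint down to the shard owned by this
    # tensor-parallel rank before loading it.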
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
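        # Each tensor-parallel rank computes only its shard of the vocab dimension; all_gather
        # stacks the per-rank shards along dim 0, and the rearrange below reassembles them.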
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()


@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
    """Check that our implementation (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = AutoModelForCausalLM.from_pretrained(
        model_name, device_map={"": device}, trust_remote_code=True
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        fused_ft_kernel=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
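        # teacher_outputs forces decoding to follow the HF-generated tokens, so the per-step
        # scores line up with out_hf.scores position by position.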
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
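    # update_graph_cache pre-captures the CUDA graphs used for single-token decoding, so the
    # timed generate(..., cg=True) call below only replays them.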
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        fused_ft_kernel=True,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
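    # .scores holds one (batch, vocab) logits tensor per generated token; stack them into
    # (batch, num_generated, vocab) so they can be compared against the reference logits.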
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    dtype = torch.float16
    config = falcon_config_to_gpt2_config(
        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    )
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused MLP for "gelu" activation
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
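    # Pad the vocabulary so the embedding and LM head can be sharded evenly across the
    # tensor-parallel ranks.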
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU 0 and GPU 1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = remap_state_dict_hf_falcon(
        state_dict_from_pretrained(model_name), config
    )

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        fused_ft_kernel=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        fused_ft_kernel=True,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        model_hf = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", trust_remote_code=True
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref
        logits_hf = torch.stack(out_hf.scores, dim=1)

        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
        assert torch.equal(logits_cg, logits)