# Copyright (c) 2023, Tri Dao.

# To run the huggingface implementation, we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
# and repeat for 13B, 30B, 65B

import os
import time
from pathlib import Path

import pytest
import torch
from einops import rearrange
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
    config_from_checkpoint,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
    remap_state_dict_meta_llama,
    state_dicts_from_checkpoint,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained

# Directory containing this test file. The tests fall back to
# `current_dir.parent.parent / 'checkpoints'` when CHECKPOINT_DIR is not set.
current_dir = Path(__file__).parent.absolute()


def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
    """Load LLaMa weights and remap them to the key layout used by our GPT model.

    Args:
        checkpoint_path: directory that holds the LLaMa checkpoints.
        model_name: checkpoint name, e.g. "7B"; the HF weights are read from the
            `<model_name>-hf` subdirectory of `checkpoint_path`.
        config: config passed through to the remapping / combining helpers.
        checkpoint_format: "meta" selects the Meta checkpoint path; any other value
            selects the Hugging Face path.

    Returns:
        The remapped state dict (for "meta", the shards combined into one).
    """
    if checkpoint_format != "meta":
        hf_dir = Path(checkpoint_path) / f'{model_name}-hf'
        hf_state_dict = state_dict_from_pretrained(hf_dir)
        return remap_state_dict_hf_llama(hf_state_dict, config)

    # Meta checkpoints come as a list of state dicts: remap each one, then merge them.
    remapped_shards = [
        remap_state_dict_meta_llama(shard, config)
        for shard in state_dicts_from_checkpoint(checkpoint_path, model_name)
    ]
    return combine_state_dicts_tp(remapped_shards, config)


@pytest.mark.parametrize('model_name', ["7B"])
def test_llama_state_dict(model_name):
    """Check that a remapped Meta checkpoint has the same keys and shapes as our model.

    Only the first state dict returned by `state_dicts_from_checkpoint` is remapped and
    compared against the state dict of a `GPTLMHeadModel` built from the same config.
    """
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
    model = GPTLMHeadModel(config, device='meta')  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict:
        # Include the key in the failure message so a shape mismatch is easy to locate.
        assert state_dict[k].shape == pretrained_state_dict[k].shape, k


@pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_optimized(model_name, checkpoint_format):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    # `seqlens` is not used below; it is kept so that `input_ids` is drawn from the same
    # RNG state as before.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    # Drop our model before loading the HF models.
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map={"": device})
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.model(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    # Our error relative to the HF fp32 reference must stay within 3x the error of HF fp16.
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_parallel(model_name, world_size, checkpoint_format):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.

    Our model runs tensor-parallel across `world_size` ranks; the HF comparison is done on
    rank 0 only.
    """
    from apex.transformer import parallel_state

    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    # `seqlens` is not used below; it is kept so that `input_ids` is drawn from the same
    # RNG state as before.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)
    with torch.no_grad():
        # Gather the per-rank pieces so every rank holds the full output / logits.
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
    del model
    # Tear down the model-parallel state (as test_llama_parallel_generation does) so the next
    # parametrized case in the same process can call initialize_model_parallel again.
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        model_hf = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        # Our error relative to the HF fp32 reference must stay within 2x the error of HF fp16.
        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
        print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
        print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
        print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
        print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
        print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
        assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()


# @pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize('model_name', ["7B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_generation(model_name, checkpoint_format):
    """Check generation scores of our model against the HF implementation.

    HF fp16 generates the sequence; our model is then run with those tokens as
    `teacher_outputs`, with and without CUDA graph. Our scores, compared to the HF fp32
    logits on the same sequence, must be within 2x the error of the HF fp16 scores, and the
    CUDA-graph scores must equal the non-graph scores exactly.
    """
    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    device = 'cuda'
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf')
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)

    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                torch_dtype=dtype, device_map={"": device})
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                               return_dict_in_generate=True, output_scores=True)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
    del model_hf

    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 device_map='auto')
    model_ref.eval()
    with torch.no_grad():
        # Logits at positions seqlen-1 .. end-1 are the ones that predict the generated tokens.
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
    del model_ref

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    print('Without CUDA graph')
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(input_ids=input_ids, max_length=max_length,
                         eos_token_id=eos_token_id, fused_ft_kernel=True,
                         return_dict_in_generate=True, output_scores=True, timing=True,
                         teacher_outputs=out_hf.sequences)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print('With CUDA graph')
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                            fused_ft_kernel=True, cg=True,
                            return_dict_in_generate=True, output_scores=True, timing=True,
                            teacher_outputs=out_hf.sequences)
    torch.cuda.synchronize()
    print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1):-1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()

    print(f'HF fp16 logits max diff: {hf_error}')
    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')

    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
@pytest.mark.parametrize('checkpoint_format', ["meta", "hf"])
def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print('Without CUDA graph')
    out = model.generate(
        input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
        vocab_size=config.vocab_size, fused_ft_kernel=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True, output_scores=True, timing=True
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print('With CUDA graph')
    out_cg = model.generate(
        input_ids=input_ids, max_length=max_length, tensor_parallel=world_size,
        vocab_size=config.vocab_size, fused_ft_kernel=True, cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True, output_scores=True, timing=True
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_hf = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids, max_length=max_length, return_dict_in_generate=True,
                output_scores=True
            )
        torch.cuda.synchronize()
        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
        del model_hf

        model_ref = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
        )
        model_ref.eval()
        with torch.inference_mode():
            # Move the reference logits to this rank's device, as the other tests in this
            # file do after loading with device_map="auto".
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
        del model_ref
        logits_hf = torch.stack(out_hf.scores, dim=1)

        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f'HF fp16 logits max diff: {hf_error}')
        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
        assert torch.equal(logits_cg, logits)


@torch.no_grad()
@pytest.mark.parametrize('world_size', [2])
def test_llama_parallel_uneven_num_heads(world_size):
    """Run the tensor-parallel forward-pass comparison with `world_size + 1` attention heads.

    Rank 0 saves a small randomly initialized HF LLaMa model under the checkpoint directory,
    every rank loads it, and rank 0 compares our output / logits to the HF fp32 and fp16
    models. The temporary checkpoint is removed by rank 0 afterwards, also when the test
    fails.
    """
    import shutil

    from apex.transformer import parallel_state

    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama'
    num_attention_heads = world_size + 1
    model_name = f'teeny-{num_attention_heads}-heads'

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    dtype = torch.float16
    llama_config = LlamaConfig(
        hidden_size=256 * num_attention_heads,  # ParallelGatedMlp hidden_features must be divisible by 256
        intermediate_size=256 * num_attention_heads * 4,
        num_hidden_layers=4,
        num_attention_heads=num_attention_heads,
        initializer_range=0.5,  # Set crazy init range so we don't have near zero weights implying a vacuous test.
    )
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    # `seqlens` is not used below; it is kept so that `input_ids` is drawn from the same
    # RNG state as before.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device=device)

    # Create a shared test model.
    if rank == 0:
        LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
    # Other ranks wait here until rank 0 has written the checkpoint.
    torch.distributed.barrier()

    try:
        # Run the standard forward pass test.
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="hf"
        )
        model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
        model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
        model.eval()

        # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)

        if rank == 0:
            model_ref = LlamaForCausalLM.from_pretrained(
                Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
            )
            model_ref.eval()
            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
            del model_ref

            model_hf = LlamaForCausalLM.from_pretrained(
                Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
            )
            model_hf.eval()
            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
            del model_hf

            print(f'Output max diff: {(out - out_ref).abs().max().item()}')
            print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
            print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
            print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
            assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

            print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
            print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
            print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
            print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
            assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
    finally:
        # Tear down the model-parallel state, as test_llama_parallel_generation does, so a
        # later test in the same process can initialize it again.
        parallel_state.destroy_model_parallel()
        # Remove the temporary checkpoint even when an assertion above failed.
        if rank == 0:
            shutil.rmtree(checkpoint_path / f'{model_name}-hf')