# Copyright (c) 2023, Tri Dao.

# To run the HuggingFace implementation of LLaMa (1), we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
# and repeat for 13B, 30B, 65B.

import os
import shutil
import time
from pathlib import Path

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
    config_from_checkpoint,
    inv_remap_state_dict_hf_llama,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
    remap_state_dict_meta_llama,
    state_dicts_from_checkpoint,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM

current_dir = Path(__file__).parent.absolute()


def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
    """Load a LLaMa checkpoint and remap it into a single GPTLMHeadModel state dict.

    Args:
        checkpoint_path: Directory holding the LLaMa checkpoints.
        model_name: Checkpoint name, e.g. "7B".
        config: GPT2-style config (from ``llama_config_to_gpt2_config``) used by the remap helpers.
        checkpoint_format: ``"meta"`` to read the original Meta checkpoint via
            ``state_dicts_from_checkpoint(checkpoint_path, model_name)``, or ``"hf"`` to read the
            HuggingFace-format copy stored at ``checkpoint_path / f"{model_name}-hf"``.

    Returns:
        A state dict meant to be passed to ``GPTLMHeadModel(config).load_state_dict`` (directly,
        or after ``shard_state_dict_tp`` for tensor-parallel models).

    Raises:
        ValueError: If ``checkpoint_format`` is neither ``"meta"`` nor ``"hf"``. Previously any
            unknown value was silently treated as ``"hf"``.
    """
    if checkpoint_format == "meta":
        # state_dicts_from_checkpoint returns a list of state dicts (presumably one per
        # model-parallel shard — TODO confirm); remap each, then merge them into one.
        ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
        pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
        return combine_state_dicts_tp(pretrained_state_dicts, config)
    if checkpoint_format == "hf":
        pretrained_state_dict = state_dict_from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf"
        )
        return remap_state_dict_hf_llama(pretrained_state_dict, config)
    raise ValueError(
        f"Unknown checkpoint_format {checkpoint_format!r}; expected 'meta' or 'hf'"
    )


@pytest.mark.parametrize("model_name", ["7B"])
def test_llama_state_dict(model_name):
    """The remapped Meta checkpoint must expose exactly our model's keys, with matching shapes."""
    ckpt_root = Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints"))
    checkpoint_path = ckpt_root / "llama"

    llama_config = config_from_checkpoint(checkpoint_path, model_name)
    config = llama_config_to_gpt2_config(llama_config)

    # Remap only the first state dict returned for this checkpoint.
    first_ckpt = state_dicts_from_checkpoint(checkpoint_path, model_name)[0]
    pretrained_state_dict = remap_state_dict_meta_llama(first_ckpt, config)

    # Without device='meta' init is very slow; only keys and shapes are inspected here.
    model_state_dict = GPTLMHeadModel(config, device="meta").state_dict()

    assert model_state_dict.keys() == pretrained_state_dict.keys()
    for name, tensor in model_state_dict.items():
        assert tensor.shape == pretrained_state_dict[name].shape


# TinyLlama-1.1B is included to test MQA
@pytest.mark.parametrize(
    "model_name", ["meta-llama/Llama-2-7b-hf", "PY007/TinyLlama-1.1B-step-50K-105b"]
)
def test_inv_remap_state_dict_hf_llama(model_name):
    """inv_remap_state_dict_hf_llama must exactly undo remap_state_dict_hf_llama."""
    hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config = llama_config_to_gpt2_config(hf_config)

    # Drop the rotary inv_freq buffers before the round trip; only the remaining keys
    # are compared below.
    hf_state_dict = {
        name: tensor
        for name, tensor in state_dict_from_pretrained(model_name).items()
        if "rotary_emb.inv_freq" not in name
    }

    remapped = remap_state_dict_hf_llama(hf_state_dict, config)
    recovered = inv_remap_state_dict_hf_llama(remapped, config)

    assert set(recovered.keys()) == set(hf_state_dict.keys())
    for name in recovered:
        torch.testing.assert_close(recovered[name], hf_state_dict[name])


# TinyLlama-1.1B is to test MQA
@pytest.mark.parametrize(
    "model_name",
    [
        "7B",  # Llama 1
        "13B",  # Llama 1
        "meta-llama/Llama-2-13b-hf",
        "codellama/CodeLlama-7b-hf",
        "codellama/CodeLlama-13b-hf",
        "codellama/CodeLlama-34b-hf",
        "PY007/TinyLlama-1.1B-step-50K-105b",
    ],
)
def test_llama_optimized(model_name):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.

    Names containing "/" are downloaded from the HF Hub; other names (e.g. "7B") are read from
    Meta checkpoints under $CHECKPOINT_DIR/llama, with their converted HF copy in "<name>-hf".
    """
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    device = "cuda"
    from_hub = "/" in model_name  # Download from HF
    hf_source = model_name if from_hub else Path(checkpoint_path) / f"{model_name}-hf"

    if from_hub:
        llama_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    else:
        llama_config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if from_hub:
        pretrained_state_dict = remap_state_dict_hf_llama(
            state_dict_from_pretrained(model_name), config
        )
    else:
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="meta"
        )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    # Not used below, but drawing it advances the CUDA RNG and therefore fixes input_ids.
    _ = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(hf_source, device_map="auto")
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(
        hf_source, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.model(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    def _report_and_check(label, ours, ours_hf, ref):
        """Print errors vs. the fp32 reference; ours must stay within 3x of HF fp16's max error."""
        our_err = (ours - ref).abs()
        hf_err = (ours_hf - ref).abs()
        print(f"{label} max diff: {our_err.max().item()}")
        print(f"{label} mean diff: {our_err.mean().item()}")
        print(f"HF fp16 max diff: {hf_err.max().item()}")
        print(f"HF fp16 mean diff: {hf_err.mean().item()}")
        assert our_err.max().item() < 3 * hf_err.max().item()

    _report_and_check("Output", out, out_hf, out_ref)
    _report_and_check("Logits", logits, logits_hf, logits_ref)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel(model_name, world_size):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.

    Our model is sharded with tensor parallelism over ``world_size`` ranks. The HF reference models
    are only run, and the comparison only happens, on rank 0. The model-parallel groups are torn
    down before returning so that later tests in the same torchrun process can re-initialize them.
    """
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    if "/" in model_name:  # Download from HF
        config = llama_config_to_gpt2_config(
            AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        )
    else:
        config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
        config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    if "/" in model_name:  # Download from HF
        pretrained_state_dict = remap_state_dict_hf_llama(
            state_dict_from_pretrained(model_name), config
        )
    else:
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="meta"
        )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    # Not used below, but drawing it advances the CUDA RNG and therefore fixes input_ids.
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        # The hidden states come back flattened over (batch * seqlen) and (presumably) split
        # across ranks; gather them and restore the (batch, seqlen, dim) layout.
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        # Gathered along dim 0, then the per-rank slices are concatenated on the last dim.
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model
    # Tear down the model-parallel groups, as test_llama_parallel_generation does, so the next
    # parametrization / test in this torchrun process can call initialize_model_parallel again.
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = LlamaForCausalLM.from_pretrained(
            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
            device_map="auto",
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        model_hf = LlamaForCausalLM.from_pretrained(
            model_name if "/" in model_name else Path(checkpoint_path) / f"{model_name}-hf",
            torch_dtype=dtype,
            device_map="auto",
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()


# To also cover Llama 1 13B, parametrize model_name with ["7B", "13B"].
@pytest.mark.parametrize("model_name", ["7B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_generation(model_name, checkpoint_format):
    """Compare our generation scores with HuggingFace's on a random prompt.

    HF fp16 ``generate`` produces the reference sequence, and an HF fp32 forward pass over that
    sequence gives the reference logits. Our model is loaded from ``checkpoint_format`` weights and
    fed the HF sequence through ``teacher_outputs``. Its scores must stay within 2x of the HF fp16
    error, and the CUDA-graph run must produce exactly the same scores as the eager run.
    """
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    device = "cuda"
    hf_dir = Path(checkpoint_path) / f"{model_name}-hf"

    config = llama_config_to_gpt2_config(
        config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    )
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = LlamaTokenizer.from_pretrained(hf_dir)
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    def _timed(label, generate, **generate_kwargs):
        """Print ``label``, run ``generate(**generate_kwargs)`` between CUDA syncs, print the ms."""
        print(label)
        torch.cuda.synchronize()
        start = time.time()
        result = generate(**generate_kwargs)
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        return result

    model_hf = LlamaForCausalLM.from_pretrained(
        hf_dir, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    out_hf = _timed(
        "HF fp16",
        model_hf.generate,
        input_ids=input_ids,
        max_length=max_length,
        return_dict_in_generate=True,
        output_scores=True,
    )
    del model_hf

    # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(hf_dir, device_map="auto")
    model_ref.eval()
    with torch.no_grad():
        # Positions seqlen-1 .. -2 are the ones whose logits predict the generated tokens.
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
    del model_ref

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    out = _timed(
        "Without CUDA graph",
        model.generate,
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    out_cg = _timed(
        "With CUDA graph",
        model.generate,
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )

    with torch.no_grad():
        # Single full forward pass of our model over the HF sequence.
        logits_full_fwd = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)
    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()
    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")

    assert (logits_full_fwd - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
    "model_name", ["13B", "meta-llama/Llama-2-13b-hf", "codellama/CodeLlama-34b-hf"]
)
def test_llama_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.

    Generation runs tensor-parallel on every rank (eager and CUDA graph). The HF runs and all
    comparisons happen on rank 0 only.
    """
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    dtype = torch.float16
    from_hub = "/" in model_name  # Download from HF
    hf_source = model_name if from_hub else Path(checkpoint_path) / f"{model_name}-hf"

    if from_hub:
        llama_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    else:
        llama_config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format="meta")
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    if from_hub:
        pretrained_state_dict = remap_state_dict_hf_llama(
            state_dict_from_pretrained(model_name), config
        )
    else:
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="meta"
        )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    # Shared by the eager and CUDA-graph runs. No teacher forcing: the HF sequences are only
    # produced afterwards, on rank 0.
    generate_kwargs = dict(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    print("Without CUDA graph")
    out = model.generate(**generate_kwargs)

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(cg=True, **generate_kwargs)
    del model
    parallel_state.destroy_model_parallel()

    if rank != 0:
        return

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_hf = LlamaForCausalLM.from_pretrained(hf_source, torch_dtype=dtype, device_map="auto")
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    with torch.inference_mode():
        out_hf = model_hf.generate(
            input_ids=input_ids,
            max_length=max_length,
            return_dict_in_generate=True,
            output_scores=True,
        )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    model_ref = LlamaForCausalLM.from_pretrained(hf_source, device_map="auto")
    model_ref.eval()
    with torch.inference_mode():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    del model_ref

    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    hf_error = (logits_hf - logits_ref).abs().max().item()
    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
    assert torch.equal(logits_cg, logits)


@torch.no_grad()
@pytest.mark.parametrize("world_size", [2])
def test_llama_parallel_uneven_num_heads(world_size):
    """Tensor-parallel forward pass must match HF when num_attention_heads (world_size + 1) does
    not divide evenly by world_size.

    Rank 0 saves a tiny, randomly initialized Llama as an HF checkpoint. Every rank loads it into a
    sharded GPTLMHeadModel, and rank 0 compares the gathered outputs and logits against HF fp32
    (reference) and HF fp16 runs. The temporary checkpoint is deleted and the model-parallel groups
    are torn down even when an assertion fails.
    """
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    num_attention_heads = world_size + 1
    model_name = f"teeny-{num_attention_heads}-heads"
    hf_dir = checkpoint_path / f"{model_name}-hf"

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    def _compare(label, ours, ours_hf, ref):
        """Print errors vs. the fp32 reference; ours must stay within 2x of HF fp16's max error."""
        print(f"{label} max diff: {(ours - ref).abs().max().item()}")
        print(f"{label} mean diff: {(ours - ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(ours_hf - ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(ours_hf - ref).abs().mean().item()}")
        assert (ours - ref).abs().max().item() < 2 * (ours_hf - ref).abs().max().item()

    try:
        dtype = torch.float16
        llama_config = LlamaConfig(
            # ParallelGatedMlp hidden_features must be divisible by 256
            hidden_size=256 * num_attention_heads,
            intermediate_size=256 * num_attention_heads * 4,
            num_hidden_layers=4,
            num_attention_heads=num_attention_heads,
            # Set crazy init range so we don't have near zero weights implying a vacuous test.
            initializer_range=0.5,
        )
        config = llama_config_to_gpt2_config(llama_config)
        config.use_flash_attn = True
        config.fused_bias_fc = True
        config.fused_mlp = False  # We don't have fused GatedMLP yet
        config.fused_dropout_add_ln = True
        config.residual_in_fp32 = True

        torch.manual_seed(0)
        batch_size = 2
        max_seqlen = 256
        # Not used below, but drawing it advances the CUDA RNG and therefore fixes input_ids.
        seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
        input_ids = torch.randint(
            0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
        )

        # Create a shared test model; the barrier keeps other ranks from loading it too early.
        if rank == 0:
            LlamaForCausalLM(config=llama_config).save_pretrained(hf_dir)
        torch.distributed.barrier()

        # Run the standard forward pass test.
        pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
            checkpoint_path, model_name, config, checkpoint_format="hf"
        )
        model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
        model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
        model.eval()

        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
        # Free our sharded model before rank 0 loads the HF models on the same device.
        del model

        if rank == 0:
            # device_map={"": device} already places the model on `device`; no extra .to() needed.
            model_ref = LlamaForCausalLM.from_pretrained(hf_dir, device_map={"": device})
            model_ref.eval()
            out_ref = model_ref.model(input_ids).last_hidden_state
            logits_ref = model_ref(input_ids).logits
            del model_ref

            model_hf = LlamaForCausalLM.from_pretrained(
                hf_dir, torch_dtype=dtype, device_map={"": device}
            )
            model_hf.eval()
            out_hf = model_hf.model(input_ids).last_hidden_state
            logits_hf = model_hf(input_ids).logits
            del model_hf

            _compare("Output", out, out_hf, out_ref)
            _compare("Logits", logits, logits_hf, logits_ref)
    finally:
        # Always tear down the groups and the temporary checkpoint, even if an assertion failed.
        parallel_state.destroy_model_parallel()
        if rank == 0 and hf_dir.exists():
            shutil.rmtree(hf_dir)