# Copyright (c) 2023, Tri Dao.

# To run the huggingface implementation, we first need to convert the weights:
# https://github.com/huggingface/transformers/pull/21955
# python -m transformers.models.llama.convert_llama_weights_to_hf --input_dir $CHECKPOINT_DIR/llama --model_size 7B --output_dir $CHECKPOINT_DIR/llama/7B-hf
# and repeat for 13B, 30B, 65B
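#
# Note: these tests look for checkpoints under $CHECKPOINT_DIR/llama (defaulting to
# <repo_root>/checkpoints/llama), with the original Meta weights in e.g. llama/7B and the
# converted HuggingFace weights in llama/7B-hf.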

import os
import time
from pathlib import Path

current_dir = Path(__file__).parent.absolute()

import shutil

import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.models.llama import (
    config_from_checkpoint,
    inv_remap_state_dict_hf_llama,
    llama_config_to_gpt2_config,
    remap_state_dict_hf_llama,
    remap_state_dict_meta_llama,
    state_dicts_from_checkpoint,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import LlamaConfig, LlamaTokenizer
from transformers.models.llama.modeling_llama import LlamaForCausalLM


def _pretrained_state_dict_from_checkpoint(checkpoint_path, model_name, config, checkpoint_format):
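    """Load a LLaMA checkpoint and remap it to the GPTLMHeadModel layout.

    "meta" checkpoints are sharded across files, so each shard is remapped and the shards are
    then combined; "hf" checkpoints are loaded from the converted `{model_name}-hf` directory
    and remapped directly.
    """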
    if checkpoint_format == "meta":
        ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
        pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
        pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
    else:
        pretrained_state_dict = state_dict_from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf"
        )
        pretrained_state_dict = remap_state_dict_hf_llama(pretrained_state_dict, config)
    return pretrained_state_dict


@pytest.mark.parametrize("model_name", ["7B"])
def test_llama_state_dict(model_name):
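    """Check that the remapped Meta checkpoint has exactly the same keys and shapes as a freshly
    initialized GPTLMHeadModel."""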
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dict = remap_state_dict_meta_llama(ckpt_state_dicts[0], config)
    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize("model_name", ["7B", "13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_optimized(model_name, checkpoint_format):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )

    dtype = torch.float16
    device = "cuda"
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    # Without device_map, the model is loaded on the CPU, which is very slow
    # Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(
        Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
    )
    model_ref.eval()
    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(
        Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.model(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


@pytest.mark.parametrize("model_name", ["PY007/TinyLlama-1.1B-step-50K-105b"])
def test_mqa_optimized(model_name):
    """Check that our implementation of Llama with MQA/GQA (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = "cuda"
    config = llama_config_to_gpt2_config(LlamaConfig.from_pretrained(model_name))
    config.use_flash_attn = True  # FlashAttention-2 supports headdim 256
    config.fused_bias_fc = True
    config.fused_mlp = False
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    # Without device_map, the model is loaded on the CPU, which is very slow
    model_ref = LlamaForCausalLM.from_pretrained(model_name, device_map={"": device})
    model_ref.eval()

    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(remap_state_dict_hf_llama(model_ref.state_dict(), config))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
    with torch.no_grad():
        out = model.transformer(input_ids)
        logits = model(input_ids).logits
    del model

    with torch.no_grad():
        out_ref = model_ref.model(input_ids).last_hidden_state
        logits_ref = model_ref(input_ids).logits
    del model_ref

    model_hf = LlamaForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.model(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
    del model_hf

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
    assert (out - out_ref).abs().max().item() < 3 * (
        out_hf - out_ref
    ).abs().max().item()

    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
    print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
    print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
    assert (logits - logits_ref).abs().max().item() < 3 * (
        logits_hf - logits_ref
    ).abs().max().item()


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_parallel(model_name, world_size, checkpoint_format):
    """Check that our implementation of LLaMa (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )

    dtype = torch.float16
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )
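    # With tensor parallelism, the hidden states and logits come back sharded across ranks, so we
    # all-gather and rearrange them below before comparing against the reference.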
    with torch.no_grad():
        out = model.transformer(input_ids)
        out, _ = all_gather_raw(out, process_group=process_group)
        out = rearrange(out, "(b s) d -> b s d", b=batch_size)
        logits = model(input_ids).logits
        logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
        logits, _ = all_gather_raw(logits, process_group)
        logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
    del model

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_ref = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
        )
        model_ref.eval()
        with torch.no_grad():
            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
            logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        model_hf = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
        )
        model_hf.eval()
        with torch.no_grad():
            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
            logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()


# @pytest.mark.parametrize('model_name', ["7B", "13B"])
@pytest.mark.parametrize("model_name", ["7B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_generation(model_name, checkpoint_format):
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )

    dtype = torch.float16
    device = "cuda"
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    tokenizer = LlamaTokenizer.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf")
    eos_token_id = tokenizer.eos_token_id

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    model_hf = LlamaForCausalLM.from_pretrained(
        Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map={"": device}
    )
    model_hf.eval()
    print("HF fp16")
    torch.cuda.synchronize()
    start = time.time()
    out_hf = model_hf.generate(
        input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
    del model_hf

    # Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
    model_ref = LlamaForCausalLM.from_pretrained(
        Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
    )
    model_ref.eval()
    with torch.no_grad():
        logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
    del model_ref

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, device=device, dtype=dtype)
    model.load_state_dict(pretrained_state_dict)
    model.eval()

    print("Without CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        eos_token_id=eos_token_id,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    torch.cuda.synchronize()
    start = time.time()
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
        teacher_outputs=out_hf.sequences,
    )
    torch.cuda.synchronize()
    print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")

    with torch.no_grad():
        logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
    logits_hf = torch.stack(out_hf.scores, dim=1)
    logits = torch.stack(out.scores, dim=1)
    logits_cg = torch.stack(out_cg.scores, dim=1)

    del model

    hf_error = (logits_hf - logits_ref).abs().max().item()

    print(f"HF fp16 logits max diff: {hf_error}")
    print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
    print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")

    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["13B"])
@pytest.mark.parametrize("checkpoint_format", ["meta", "hf"])
def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )

    dtype = torch.float16
    config = config_from_checkpoint(checkpoint_path, model_name, checkpoint_format)
    config = llama_config_to_gpt2_config(config)
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
    )

    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print("Without CUDA graph")
    out = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print("With CUDA graph")
    out_cg = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        tensor_parallel=world_size,
        vocab_size=config.vocab_size,
        cg=True,
        # teacher_outputs=out_hf.sequences,
        return_dict_in_generate=True,
        output_scores=True,
        enable_timing=True,
    )
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_hf = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
        )
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(
                input_ids=input_ids,
                max_length=max_length,
                return_dict_in_generate=True,
                output_scores=True,
            )
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
        del model_hf

        model_ref = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
        )
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
        del model_ref
        logits_hf = torch.stack(out_hf.scores, dim=1)

        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)

        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f"HF fp16 logits max diff: {hf_error}")
        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}")
        assert torch.equal(logits_cg, logits)


@torch.no_grad()
@pytest.mark.parametrize("world_size", [2])
def test_llama_parallel_uneven_num_heads(world_size):
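    """Check the tensor-parallel forward pass when the number of attention heads (world_size + 1)
    does not divide evenly across ranks, using a small randomly initialized model."""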
    from apex.transformer import parallel_state

    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    num_attention_heads = world_size + 1
    model_name = f"teeny-{num_attention_heads}-heads"

    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    dtype = torch.float16
    llama_config = LlamaConfig(
        hidden_size=256
        * num_attention_heads,  # ParallelGatedMlp hidden_features must be divisible by 256
        intermediate_size=256 * num_attention_heads * 4,
        num_hidden_layers=4,
        num_attention_heads=num_attention_heads,
        initializer_range=0.5,  # Set crazy init range so we don't have near zero weights implying a vacuous test.
    )
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
    )

    # Create a shared test model.
    if rank == 0:
        LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
    torch.distributed.barrier()

    # Run the standard forward pass test.
    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
        checkpoint_path, model_name, config, checkpoint_format="hf"
    )
    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
    out = model.transformer(input_ids)
    out, _ = all_gather_raw(out, process_group=process_group)
    out = rearrange(out, "(b s) d -> b s d", b=batch_size)
    logits = model(input_ids).logits
    logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
    logits, _ = all_gather_raw(logits, process_group)
    logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)

    if rank == 0:
        model_ref = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", device_map="auto"
        )
        model_ref.eval()
        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
        logits_ref = model_ref(input_ids).logits.to(device=device)
        del model_ref

        model_hf = LlamaForCausalLM.from_pretrained(
            Path(checkpoint_path) / f"{model_name}-hf", torch_dtype=dtype, device_map="auto"
        )
        model_hf.eval()
        out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
        logits_hf = model_hf(input_ids).logits.to(device=device)
        del model_hf

        print(f"Output max diff: {(out - out_ref).abs().max().item()}")
        print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()

        print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
        print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
        print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
        print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
        assert (logits - logits_ref).abs().max().item() < 2 * (
            logits_hf - logits_ref
        ).abs().max().item()

        if os.path.exists(checkpoint_path / f"{model_name}-hf"):
            shutil.rmtree(checkpoint_path / f"{model_name}-hf")


@torch.no_grad()
def test_inv_remap_state_dict_hf_llama():
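    """Check that inv_remap_state_dict_hf_llama is the inverse of remap_state_dict_hf_llama:
    round-tripping a small random HF checkpoint should reproduce the original state dict."""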
    checkpoint_path = (
        Path(os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")) / "llama"
    )
    model_name = f"teeny"

    llama_config = LlamaConfig(
        num_attention_heads=2,
        hidden_size=256 * 2,
        intermediate_size=256 * 2 * 4,
        num_hidden_layers=4,
    )
    config = llama_config_to_gpt2_config(llama_config)
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = True
    config.residual_in_fp32 = True

    # Set up.
    LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")

    # inv_remap_state_dict_hf_llama should be the inverse of remap_state_dict_hf_llama
    state_dict = state_dict_from_pretrained(checkpoint_path / f"{model_name}-hf")
    state_dict = {key: val for key, val in state_dict.items() if "rotary_emb.inv_freq" not in key}
    pretrained_state_dict = remap_state_dict_hf_llama(state_dict, config)
    state_dict_recover = inv_remap_state_dict_hf_llama(pretrained_state_dict, config)

    assert set(state_dict_recover.keys()) == set(state_dict.keys())

    for key in state_dict_recover.keys():
        torch.testing.assert_close(state_dict_recover[key], state_dict[key])

    # Tear down.
    if os.path.exists(checkpoint_path / f"{model_name}-hf"):
        shutil.rmtree(checkpoint_path / f"{model_name}-hf")