# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig

from llamafactory.train.tuner import run_exp


BASE = 2  # gemm (add + mul)
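# counting convention (a sketch): an (m, k) x (k, n) GEMM costs BASE * m * k * n
# FLOPs, since each of the m * n outputs accumulates k multiply-add pairs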


def compute_model_flops(
    model_name_or_path: str,
    total_batch_size: int,
    seq_length: int,
    include_backward: bool = True,
    include_recompute: bool = False,
    include_flashattn: bool = False,
) -> int:
    r"""Calculate the FLOPs of model per forward/backward pass."""
    config = AutoConfig.from_pretrained(model_name_or_path)
    hidden_size = getattr(config, "hidden_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    intermediate_size = getattr(config, "intermediate_size", None)
    num_attention_heads = getattr(config, "num_attention_heads", None)
    num_key_value_heads = getattr(config, "num_key_value_heads", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

    # mlp module
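    # assumes a SwiGLU-style gated MLP with three projections; a plain two-matmul
    # MLP would cost 2 * BASE * hidden_size * intermediate_size per token instead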
    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

    # attn projector module
    q_flops_per_token = BASE * hidden_size * hidden_size
    o_flops_per_token = BASE * hidden_size * hidden_size
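    # with grouped-query attention, k/v projections shrink by num_key_value_heads / num_attention_heads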
    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

    # attn sdpa module
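    # q @ k^T and scores @ v are two seq x seq x hidden batched matmuls, hence the extra factor of 2;
    # this ignores the roughly 2x saving an efficient causal-mask kernel could exploit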
    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer

    # embedding module
    embedding_flops_per_token = hidden_size * vocab_size
    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
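    # untied input embedding and lm_head weights are counted as two separate projections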
    if not tie_word_embeddings:
        embedding_flops *= 2

    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
    non_embedding_coeff, embedding_coeff = 1, 1  # forward pass
    if include_backward:
        non_embedding_coeff += 2  # backward pass costs roughly 2x the forward pass
        embedding_coeff += 2

    if include_recompute:
        non_embedding_coeff += 1  # gradient checkpointing replays the forward pass

    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops

    if include_flashattn:
        total_flops += sdpa_flops  # FlashAttention recomputes attention scores during backward

    return total_flops
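
# rough sanity check (an approximation, not used by the script): with backward
# included the non-embedding coefficient is 3, and for large dense models the
# per-token forward cost approaches 2 * num_params, so the total tends toward
# the familiar 6 * num_params * num_tokens estimate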


def compute_device_flops(world_size: int) -> float:
    r"""Calculate the FLOPs of the device capability per second."""
    device_name = torch.cuda.get_device_name()
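    # peak dense tensor-core throughput per device (BF16/FP16); the values below
    # are vendor-quoted approximations, and unsupported devices need an entry here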
    if "H100" in device_name or "H800" in device_name:
        return 989 * 1e12 * world_size
    elif "A100" in device_name or "A800" in device_name:
        return 312 * 1e12 * world_size
    elif "V100" in device_name:
        return 125 * 1e12 * world_size
    elif "4090" in device_name:
        return 98 * 1e12 * world_size
    else:
        raise NotImplementedError(f"Device not supported: {device_name}.")


def calculate_mfu(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 1024,
    num_steps: int = 100,
    finetuning_type: str = "lora",
    flash_attn: str = "auto",
    deepspeed_stage: int = 0,
    disable_gc: bool = False,
    liger_kernel: bool = False,
    unsloth_gc: bool = False,
) -> float:
    r"""Calculate MFU for given model and hyper-params.

    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
    """
    args = {
        "model_name_or_path": model_name_or_path,
        "flash_attn": flash_attn,
        "disable_gradient_checkpointing": disable_gc,
        "enable_liger_kernel": liger_kernel,
        "use_unsloth_gc": unsloth_gc,
        "stage": "pt",
        "do_train": True,
        "finetuning_type": finetuning_type,
        "dataset": "c4_demo",
        "cutoff_len": seq_length,
        "output_dir": os.path.join("saves", "test_mfu"),
        "logging_strategy": "no",
        "save_strategy": "no",
        "save_only_model": True,
        "overwrite_output_dir": True,
        "per_device_train_batch_size": batch_size,
        "max_steps": num_steps,
        "bf16": True,
    }
    if deepspeed_stage in [2, 3]:
        args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"

    run_exp(args)
    if dist.is_initialized():
        dist.barrier()
        world_size = dist.get_world_size()
    else:
        world_size = 1

    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
            result = json.load(f)

        total_batch_size = batch_size * world_size
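        # MFU = achieved FLOP/s / peak FLOP/s
        #     = (train steps per second * model FLOPs per step) / aggregate device peak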
        mfu_value = (
            result["train_steps_per_second"]
            * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
            / compute_device_flops(world_size)
        )
        print(f"MFU: {mfu_value * 100:.2f}%")


if __name__ == "__main__":
    fire.Fire(calculate_mfu)
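

# multi-GPU example (a sketch; the model id and GPU count are placeholders): when
# launched with torchrun, torch.distributed is initialized and world_size is
# picked up automatically:
#   torchrun --nproc_per_node 8 cal_mfu.py --model_name_or_path meta-llama/Llama-2-7b-hf --batch_size 1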