model_merger.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
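"""Merge an FSDP-sharded checkpoint into a single HuggingFace model.

The script expects ``local_dir`` to contain per-rank shard files named
``model_world_size_{WORLD_SIZE}_rank_{RANK}.pt`` plus a ``huggingface/``
subfolder holding the model config saved during training. The merged
weights are written back into that ``huggingface/`` folder and can
optionally be pushed to the Hub.

Example (paths are illustrative):
    python model_merger.py --local_dir /path/to/checkpoint --hf_upload_path your-org/your-model
"""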

import argparse
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoModelForVision2Seq,
    PretrainedConfig,
    PreTrainedModel,
)


def merge_by_placement(tensors: List[torch.Tensor], placement: Placement):
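    """Merge the local shards of one parameter according to its DTensor placement."""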
    if placement.is_replicate():
        return tensors[0]
    elif placement.is_partial():
        raise NotImplementedError("Partial placement is not supported yet")
    elif placement.is_shard():
        return torch.cat(tensors, dim=placement.dim).contiguous()
    else:
        raise ValueError(f"Unsupported placement: {placement}")


def upload_model_to_huggingface(local_path: str, remote_path: str):
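    """Create the target repo on the Hugging Face Hub (if needed) and upload the merged model folder."""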
    # Push to hugging face
    from huggingface_hub import HfApi

    api = HfApi()
    api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
    api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="The path for your saved model")
    parser.add_argument("--hf_upload_path", default=False, type=str, help="The path of the huggingface repo to upload")
    args = parser.parse_args()
    local_dir: str = args.local_dir

    assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."

    # locate the rank-0 shard to determine the training world size from its filename
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            world_size = int(match.group(1))
            break

    assert world_size, f"No checkpoint file matching 'model_world_size_*_rank_0.pt' found in {local_dir}."

    rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
    state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
    pivot_key = sorted(state_dict.keys())[0]
    weight = state_dict[pivot_key]
    if isinstance(weight, DTensor):
        # get sharding info
        device_mesh = weight.device_mesh
        mesh = device_mesh.mesh
        mesh_dim_names = device_mesh.mesh_dim_names
    else:
        # plain tensors carry no mesh info; assume a 1-D FSDP mesh spanning the full world size
        mesh = np.array([int(world_size)], dtype=np.int64)
        mesh_dim_names = ("fsdp",)

    print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")

    assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."

    if "tp" in mesh_dim_names:
        # fsdp * tp
        total_shards = mesh.shape[-1] * mesh.shape[-2]
        mesh_shape = (mesh.shape[-2], mesh.shape[-1])
    else:
        # fsdp
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f"Processing {total_shards} model shards in total.")
    # pre-allocate one slot per shard; rank 0 is already loaded
    model_state_dict_lst = [state_dict] + [None] * (total_shards - 1)

    def process_one_shard(rank, model_state_dict_lst):
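        """Load the checkpoint shard for ``rank`` and store it in ``model_state_dict_lst[rank]``."""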
        model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        model_state_dict_lst[rank] = state_dict
        return state_dict

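    # load the remaining shards concurrently; each worker fills its rank's slot in place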
    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
        futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(1, total_shards)]
        for future in futures:
            future.result()  # surface any exception raised while loading a shard

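    # gather, for every parameter key, the local shard from each rank (and its DTensor placement, if any)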
    state_dict: Dict[str, List[torch.Tensor]] = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
        state_dict[key] = []
        for model_state_dict in model_state_dict_lst:
            try:
                tensor = model_state_dict.pop(key)
            except KeyError:
                print(f"Cannot find key {key} in one of the shards, skipping.")
                continue

            if isinstance(tensor, DTensor):
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # replicated placement at ddp dimension can be discarded
                if mesh_dim_names[0] == "ddp":
                    placements = placements[1:]

                if key not in param_placements:
                    param_placements[key] = placements
                else:
                    assert param_placements[key] == placements
            else:
                state_dict[key].append(tensor.bfloat16())

    del model_state_dict_lst

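    # merge the collected shards: DTensor params are concatenated along their shard dim, plain tensors along dim 0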
    for key in sorted(state_dict):
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue

        if key in param_placements:
            # merge shards
            placements: Tuple[Shard, ...] = param_placements[key]
            if len(mesh_shape) == 1:
                # 1-D list, FSDP without TP
                assert len(placements) == 1
                shards = state_dict[key]
                state_dict[key] = merge_by_placement(shards, placements[0])
            else:
                # 2-D list, FSDP + TP
                raise NotImplementedError("FSDP + TP is not supported yet.")
        else:
            state_dict[key] = torch.cat(state_dict[key], dim=0)

    print("Merge completed.")
    hf_path = os.path.join(local_dir, "huggingface")
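    # the 'huggingface' subfolder must already contain the model config for AutoConfig to load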
    config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
    architectures: List[str] = getattr(config, "architectures", ["Unknown"])

    if "ForTokenClassification" in architectures[0]:
        AutoClass = AutoModelForTokenClassification
    elif "ForCausalLM" in architectures[0]:
        AutoClass = AutoModelForCausalLM
    elif "ForConditionalGeneration" in architectures[0]:
        AutoClass = AutoModelForVision2Seq
    else:
        raise NotImplementedError(f"Unknown architecture {architectures}.")

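    # build the model skeleton on the meta device (no memory allocated for weights), then materialize
    # empty parameters on CPU; save_pretrained below writes the merged state_dict as the real weights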
    with torch.device("meta"):
        model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)

    assert isinstance(model, PreTrainedModel)
    model.to_empty(device="cpu")

    print(f"Saving model to {hf_path}...")
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict, model

    if args.hf_upload_path:
        upload_model_to_huggingface(hf_path, args.hf_upload_path)