# shard_model.py
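"""Shard a BLOOM checkpoint into one state dict per tensor-parallel rank.

Column-parallel parameters (fused QKV, MLP dense_h_to_4h, word embeddings,
lm_head) are split along dim 0, row-parallel parameters (attention dense,
MLP dense_4h_to_h) along dim 1, and all remaining parameters are duplicated
on every rank.
"""
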
from pathlib import Path

import torch
from torch import nn
from transformers import AutoModelForCausalLM


def match_suffix(text, suffix):
    return text.endswith(suffix)


def shard_model(model_name: str, path: Path, tp_world_size: int, dtype: torch.dtype):
    """BLOOM-specific sharding: split a checkpoint into `tp_world_size` shards.

    Returns the list of paths the per-rank state dicts are saved to.
    """
    save_paths = [
        path / f"{model_name}_tp-rank-{tp_rank}-of-{tp_world_size}.pty"
        for tp_rank in range(tp_world_size)
    ]
    if all(save_path.exists() for save_path in save_paths):
        print("Loading already cached values")
        return save_paths

    # Load the full checkpoint; local_files_only=True assumes the weights are
    # already present in the local Hugging Face cache.
    model: nn.Module = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=dtype, local_files_only=True
    )

    # One state dict per tensor-parallel rank
    shards_state_dicts = [{} for _ in range(tp_world_size)]
    state_dict = model.state_dict()
    keys = list(state_dict.keys())
    for state_name in keys:
        print(state_name)
        state = state_dict[state_name]
        if any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.query_key_value.weight",
                "self_attention.query_key_value.bias",
                "mlp.dense_h_to_4h.weight",
                "mlp.dense_h_to_4h.bias",
                "transformer.word_embeddings.weight",
                "lm_head.weight",
            ]
        ):
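            # Column-parallel parameters: split along the output dimension (dim 0),
            # each rank keeping one contiguous block of rows.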
            output_size = state.shape[0]
            assert output_size % tp_world_size == 0
            block_size = output_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=0)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[0] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.weight",
                "mlp.dense_4h_to_h.weight",
                # "lm_head.weight" would never be reached here: it is already
                # matched by the column-parallel branch above.
            ]
        ):
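            # Row-parallel parameters: split along the input dimension (dim 1),
            # so each rank computes a partial output that is reduced across ranks.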
            input_size = state.shape[1]
            assert input_size % tp_world_size == 0
            block_size = input_size // tp_world_size
            sharded_weights = torch.split(state, block_size, dim=1)
            assert len(sharded_weights) == tp_world_size
            for tp_rank, shard in enumerate(sharded_weights):
                assert shard.shape[1] == block_size
                shards_state_dicts[tp_rank][state_name] = shard.detach().clone()
        elif any(
            match_suffix(state_name, candidate)
            for candidate in [
                "self_attention.dense.bias",
                "mlp.dense_4h_to_h.bias",
            ]
        ):
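            # Row-parallel biases must be added exactly once after the cross-rank
            # reduction: keep them on rank 0 and store zeros on every other rank.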
            shards_state_dicts[0][state_name] = state.detach().clone()
            for tp_rank in range(1, tp_world_size):
                shards_state_dicts[tp_rank][state_name] = torch.zeros_like(state)
        else:
            # We duplicate parameters across tp ranks
            for tp_rank in range(tp_world_size):
                shards_state_dicts[tp_rank][state_name] = state.detach().clone()

        del state_dict[state_name]  # delete key from state_dict
        del state  # delete tensor

    # Save one state dict per tensor-parallel rank
    for save_path, shard_state_dict in zip(save_paths, shards_state_dicts):
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(shard_state_dict, save_path)

    return save_paths


if __name__ == "__main__":
    model_name = "bigscience/bloom"
    save_path = Path("/data/shards")
    tp_world_size = 8
    dtype = torch.bfloat16

    shard_model(model_name, save_path, tp_world_size=tp_world_size, dtype=dtype)
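
    # Usage sketch (illustrative, not from the original script): a tensor-parallel
    # rank `tp_rank` would typically load its shard from the returned path list,
    # e.g. `state_dict = torch.load(paths[tp_rank], map_location="cpu")`, and pass
    # it to its rank-local model; the exact loading code depends on the serving
    # implementation.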