flash_neox.py 5.68 KB
Newer Older
1
2
3
4
5
6
import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
7
from transformers import AutoTokenizer, AutoConfig
8
from typing import Optional, List
9

10
11
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
12
13
14
15
16
17
18
19
20
21
22
23
24
    FlashGPTNeoXForCausalLM,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
)

tracer = trace.get_tracer(__name__)


25
class FlashNeoX(FlashCausalLM):
    """GPT-NeoX causal language model built on ``FlashCausalLM``.

    The class only selects ``FlashGPTNeoXForCausalLM`` as the model class and
    leaves all loading logic to the ``FlashCausalLM`` constructor.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
    ):
        """Delegate construction to ``FlashCausalLM`` with the NeoX model class.

        Args:
            model_id: Identifier of the model to load.
            revision: Optional revision of the model; forwarded unchanged.
            quantize: Optional quantization option; forwarded unchanged.
        """
        model_cls = FlashGPTNeoXForCausalLM
        # Arguments stay positional, in the same order the parent receives
        # them today: (model class, model id, revision, quantize).
        super().__init__(model_cls, model_id, revision, quantize)


class FlashNeoXSharded(FlashNeoX):
    """Sharded variant of ``FlashNeoX`` that loads one slice of the weights per rank.

    The model is first instantiated under ``init_empty_weights()`` and then
    filled by ``load_weights`` with the slices of the ``.safetensors``
    checkpoint tensors that belong to the current rank.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
    ):
        """Set up the process group, build the model and load this rank's weights.

        Args:
            model_id: Identifier of the model to load.
            revision: Optional revision of the model.
            quantize: Optional quantization option; ``None`` disables it.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashNeoX is only available on GPU")
        # One CUDA device per rank, half precision.
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Build the module structure under init_empty_weights(); the actual
        # tensors are attached by load_weights() below.
        with init_empty_weights():
            model = FlashGPTNeoXForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        self.model = model.eval().to(device)
        torch.distributed.barrier(group=self.process_group)
        # NOTE: super(FlashCausalLM, self) intentionally starts the MRO lookup
        # *after* FlashCausalLM, so neither FlashNeoX.__init__ nor
        # FlashCausalLM.__init__ runs here -- presumably because the model and
        # tokenizer have already been built above (confirm against the base
        # class before changing this call).
        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        """Attach this rank's slice of every checkpoint tensor to ``model``.

        For each tensor name found in ``filenames`` the owning submodule is
        looked up and the tensor is sliced according to the module type:

        * ``TensorParallelColumnLinear`` and ``TensorParallelEmbedding``:
          sliced along dimension 0.
        * ``TensorParallelRowLinear``: ``weight`` is sliced along dimension 1;
          any other parameter (the bias) is kept whole on rank 0 and replaced
          by zeros on the other ranks.
        * ``embed_out.weight`` when ``model.gpt_neox.tp_embeddings`` is truthy:
          sliced along dimension 0.
        * anything else: loaded whole.

        The result is cast to ``dtype`` and stored directly in the module's
        ``_parameters`` (if the name is a known parameter) or ``_buffers``.
        Finally ``model.post_load_weights(quantize)`` is called.

        Args:
            model: Model whose submodules receive the tensors.
            filenames: Paths of the safetensors files to read.
            quantize: Quantization option; when not ``None`` tensors are read
                onto the CPU instead of ``device``.
            device: Device the tensors are read onto when not quantizing.
            dtype: Dtype every loaded tensor is cast to.
            rank: Index of the current shard.
            world_size: Total number of shards.

        Raises:
            ValueError: If a loaded slice does not match the shape of the
                corresponding existing parameter.
        """

        def shard_bounds(size: int):
            """Return the ``(start, stop)`` range of this rank's equal-sized block."""
            block_size = size // world_size
            start = rank * block_size
            return start, start + block_size

        parameters = dict(model.named_parameters())
        # Loop-invariant: read straight onto the target device unless a
        # quantization option is set, in which case read onto the CPU.
        load_device = str(device) if quantize is None else "cpu"

        for file in filenames:
            with safe_open(file, framework="pt", device=load_device) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            start, stop = shard_bounds(slice_.get_shape()[1])
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            # Fall back to a plain read when the tensor cannot
                            # be sliced. Only Exception is caught so that
                            # KeyboardInterrupt / SystemExit still propagate.
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    # Known parameters replace the existing entry; everything
                    # else is registered as a buffer.
                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        model.post_load_weights(quantize)