# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from abc import ABC, abstractmethod


class MetaTensorContainer(ABC):
    """Container mixin adding meta-tensor support to a transformer container.

    When the container's parameters live on the 'meta' device (shape-only
    tensors with no storage), the usual tensor-parallel sharding, data copy,
    and transpose steps are skipped; real weights must instead be loaded from
    a checkpoint via the abstract ``load_params`` hook.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Whether this container's tensors are on the meta device; refreshed
        # in initialize_tensors() from the actual qkvw tensor.
        self.is_meta = False
        # Meta tensors carry no data, so checkpoint loading is mandatory.
        self.ckpt_load_enabled = True

    def initialize_tensors(self, enable_training=False):
        """Initialize tensors, then record whether they are meta tensors.

        qkvw is used as the representative parameter: if it is a meta tensor,
        the whole container is treated as meta.
        """
        super().initialize_tensors(enable_training=enable_training)
        self.is_meta = self.qkvw.is_meta

    def apply_tensor_parallelism(self, mp_replace=None, mp_group=None, tp_size=None):
        """Shard parameters for tensor parallelism, unless they are meta.

        Meta tensors cannot be sharded (no data); in that case only the
        absence of the qkv/dense biases is propagated onto the module so the
        inference module knows to run without them.
        """
        if self.is_meta:
            if self.qkvb is None:
                self.module.attention.attn_qkvb = None
            if self.dense_b is None:
                self.module.attention.attn_ob = None
        else:
            super().apply_tensor_parallelism(mp_replace, mp_group, tp_size)

    def copy_data_to_new_module(self):
        """Copy layer-norm data to the module, unless tensors are meta.

        In the meta case there is no data to copy; only a missing attn_nw
        (None) is mirrored onto the module's mlp along with attn_nb.
        """
        if self.is_meta:
            if self.attn_nw is None:
                self.module.mlp.attn_nw = self.attn_nw
                self.module.mlp.attn_nb = self.attn_nb
        else:
            super().copy_data_to_new_module()

    def transpose(self):
        """Transpose parameters for inference; a no-op for meta tensors."""
        if not self.is_meta:
            super().transpose()

    @abstractmethod
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        """
        Load all the transformer parameters from the checkpoint file (sd).
        In addition to the parameter names, we require two
        more parameters to help read the data correctly
        from the checkpoint and split the qkv heads in the
        right order:
            1. `use_load_prefix` (Default: False): this specifies
                whether we need to use the name of first abstraction
                layer of the model for searching the parameter's name
                in a checkpoint file. For more information of how this
                is used please see
                https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py
            2. `split_qkv` (Default: True): we use this flag when splitting
                the qkv parameter into heads. If it is False, it means the heads
                of q, k, and v are stored together and need to be split in the
                DeepSpeed-Inference API.
        """
        raise NotImplementedError("A load_params() function must be defined in the model container \
                                  when inheriting the MetaTensorContainer feature")