grok1_policy.py 3.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import Dict, Union

import torch.nn as nn

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class Grok1Policy(Policy):
    """Shardformer policy for Grok-1: shards attention projections, the MoE
    expert MLPs, and the token embedding across tensor-parallel ranks."""

    def config_sanity_check(self):
        """No standalone config constraints; divisibility is checked in preprocess."""
        pass

    def preprocess(self) -> nn.Module:
        """Validate that model dimensions are divisible by the tensor-parallel size.

        Returns:
            nn.Module: the model, unmodified.
        """
        if self.shard_config.enable_tensor_parallelism:
            world_size = self.shard_config.tensor_parallel_size
            vocab_size = self.model.config.vocab_size
            assert vocab_size % world_size == 0, f"vocab_size {vocab_size} must be divisible by world_size {world_size}"
            # Fix: module_policy divides the head counts with //, which would
            # silently truncate (corrupting the sharded attention config) if
            # they are not divisible — validate them up front like vocab_size.
            num_heads = self.model.config.num_attention_heads
            assert (
                num_heads % world_size == 0
            ), f"num_attention_heads {num_heads} must be divisible by world_size {world_size}"
            num_kv_heads = self.model.config.num_key_value_heads
            assert (
                num_kv_heads % world_size == 0
            ), f"num_key_value_heads {num_kv_heads} must be divisible by world_size {world_size}"
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Build the sharding policy for Grok-1 decoder layers and embeddings.

        Returns:
            Dict mapping target module (class name string) to its
            ModulePolicyDescription; empty when tensor parallelism is off.
        """
        policy = {}
        if self.shard_config.enable_tensor_parallelism:
            # Per-rank attention sizes after splitting heads across TP ranks.
            decoder_attribute_replacement = {
                "attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "attn.num_key_value_heads": self.model.config.num_key_value_heads
                // self.shard_config.tensor_parallel_size,
            }
            # q/k/v are column-parallel, the output projection row-parallel,
            # so each attention block needs no extra communication in between.
            decoder_submodule_replacement = [
                SubModuleReplacementDescription(
                    suffix="attn.q_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.k_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.v_proj",
                    target_module=Linear1D_Col,
                ),
                SubModuleReplacementDescription(
                    suffix="attn.o_proj",
                    target_module=Linear1D_Row,
                ),
            ]
            # Each MoE expert follows the same col/col/row split pattern
            # (linear and linear_v feed linear_1).
            for i in range(self.model.config.num_experts):
                decoder_submodule_replacement.extend(
                    [
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_v",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix=f"moe_block.experts[{i}].linear_1",
                            target_module=Linear1D_Row,
                        ),
                    ]
                )

            policy["DecoderLayer"] = ModulePolicyDescription(
                attribute_replacement=decoder_attribute_replacement,
                sub_module_replacement=decoder_submodule_replacement,
            )
            # Split the vocabulary embedding across ranks.
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="embed_tokens",
                    target_module=VocabParallelEmbedding1D,
                ),
                policy=policy,
                target_key="Grok1Model",
            )
        return policy

    def postprocess(self):
        """No post-sharding fixup needed; return the model as-is."""
        return self.model


class Grok1ModelPolicy(Grok1Policy):
    """Policy for the bare ``Grok1Model`` backbone (no LM head); inherits all
    sharding rules from :class:`Grok1Policy` without additions."""

    pass


class Grok1ForCausalLMPolicy(Grok1Policy):
    """Policy for the causal-LM variant: base Grok-1 sharding plus a
    column-parallel language-model head."""

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        """Return the base policy augmented with an ``lm_head`` replacement."""
        causal_policy = super().module_policy()
        # Gather the sharded logits only when parallel output is not requested.
        lm_head_description = SubModuleReplacementDescription(
            suffix="lm_head",
            target_module=Linear1D_Col,
            kwargs={"gather_output": not self.shard_config.parallel_output},
        )
        self.append_or_create_submodule_replacement(
            description=lm_head_description,
            policy=causal_policy,
            target_key="Grok1ModelForCausalLM",
        )
        return causal_policy