mixtral_layer.py 4.94 KB
Newer Older
Xuanlei Zhao's avatar
Xuanlei Zhao committed
1
from typing import Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from colossalai.lazy import LazyInitContext
from colossalai.moe import MOE_MANAGER
from colossalai.moe._operation import MoeInGradScaler, MoeOutGradScaler, all_to_all_uneven
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_info
Xuanlei Zhao's avatar
Xuanlei Zhao committed
11
12


Hongxin Liu's avatar
Hongxin Liu committed
13
14
15
16
class EPMixtralSparseMoeBlock(MixtralSparseMoeBlock):
    """Mixtral sparse-MoE block whose experts are partitioned across an expert-parallel (EP) group.

    Each rank of the EP group is responsible for a contiguous slice of
    ``self.experts``. During ``forward`` tokens are routed with the regular
    top-k gate, exchanged between ranks so that every token reaches the rank
    holding its selected expert, processed there, and sent back to be combined
    with the routing weights.
    """

    def __init__(self, config):
        """Build the underlying Mixtral block, then configure expert parallelism."""
        super().__init__(config)
        self.setup_ep()

    def setup_ep(self):
        """Read the EP layout from ``MOE_MANAGER`` and keep only this rank's experts.

        Sets ``ep_group``, ``ep_size``, ``ep_rank``, ``num_experts_per_ep`` and
        ``expert_start_idx`` on the module. When no EP group is configured the
        block behaves as a single-rank group (size 1, rank 0).

        Raises:
            AssertionError: if ``num_experts`` is not divisible by the EP group size.
        """
        _, moe_info = MOE_MANAGER.get_info(self.num_experts)
        ep_group = moe_info.ep_group
        self.ep_size = dist.get_world_size(ep_group) if ep_group is not None else 1
        self.ep_rank = dist.get_rank(ep_group) if ep_group is not None else 0
        assert self.num_experts % self.ep_size == 0, (
            f"num_experts ({self.num_experts}) must be divisible by ep_size ({self.ep_size})"
        )
        self.ep_group = ep_group
        self.num_experts_per_ep = self.num_experts // self.ep_size
        self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
        held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]
        # NOTE(review): presumably drops the tensors of every expert that is not
        # held by this rank -- confirm against `set_tensors_to_none`.
        set_tensors_to_none(self.experts, exclude=set(held_experts))
        for p in self.experts.parameters():
            set_moe_tensor_info(p, moe_info)

    @staticmethod
    def from_native_module(module: MixtralSparseMoeBlock, *args, **kwargs) -> "EPMixtralSparseMoeBlock":
        """Convert an existing ``MixtralSparseMoeBlock`` in place into the EP variant.

        The module is materialized through ``LazyInitContext``, its class is
        swapped to ``EPMixtralSparseMoeBlock`` and ``setup_ep`` is run on it.
        Extra positional and keyword arguments are accepted but ignored.

        Returns:
            The same ``module`` object, now an ``EPMixtralSparseMoeBlock``.
        """
        LazyInitContext.materialize(module)
        module.__class__ = EPMixtralSparseMoeBlock
        module.setup_ep()
        return module

    @staticmethod
    def _expert_forward(expert, states: torch.Tensor) -> torch.Tensor:
        """Apply one expert's gated MLP: ``w2(act_fn(w1(x)) * w3(x))``."""
        return expert.w2(expert.act_fn(expert.w1(states)) * expert.w3(states))

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Route tokens to their top-k experts across the EP group and combine the results.

        Args:
            hidden_states: tensor unpacked as ``(batch_size, sequence_length, hidden_dim)``.

        Returns:
            A tuple ``(output_states, router_logits)`` where ``output_states`` has
            the same ``(batch_size, sequence_length, hidden_dim)`` shape as the
            input and ``router_logits`` is the raw gate output of shape
            ``(batch_size * sequence_length, n_experts)``.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        # Softmax in float32, keep the top-k experts per token and renormalize
        # their weights so they sum to one.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Flatten the choices k-major (all 1st choices, then all 2nd choices, ...)
        # to match the layout of `hidden_states.repeat(top_k, 1)`, then sort the
        # replicated tokens by expert id so each expert's tokens are contiguous.
        selected_experts = selected_experts.t().reshape(-1)
        selected_experts_idx = selected_experts.argsort()
        dispatch_states = hidden_states.repeat(self.top_k, 1)[selected_experts_idx]
        # Per-expert token counts on this rank; exchange them so each rank learns
        # how many tokens it will receive for each of its splits.
        input_split_sizes = selected_experts.bincount(minlength=self.num_experts)
        output_split_sizes = torch.zeros_like(input_split_sizes)
        dist.all_to_all_single(output_split_sizes, input_split_sizes, group=self.ep_group)

        # Collapse per-expert counts into per-rank counts for the token exchange.
        input_split_list = input_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_split_list = output_split_sizes.view(self.ep_size, self.num_experts_per_ep).sum(dim=-1).tolist()
        output_states, _ = all_to_all_uneven(dispatch_states, input_split_list, output_split_list, self.ep_group)
        # compute expert output
        # NOTE(review): MoeInGradScaler / MoeOutGradScaler are project helpers;
        # presumably they only rescale gradients by ep_size -- confirm.
        output_states = MoeInGradScaler.apply(output_states, self.ep_size)
        if output_states.size(0) > 0:
            if self.num_experts_per_ep == 1:
                # no need to split
                output_states = self._expert_forward(self.experts[self.expert_start_idx], output_states)
            else:
                # Split i is handled by local expert `i % num_experts_per_ep`;
                # empty splits are skipped.
                output_states_splits = output_states.split(output_split_sizes.tolist())
                output_states_list = [
                    self._expert_forward(
                        self.experts[self.expert_start_idx + i % self.num_experts_per_ep], split_states
                    )
                    for i, split_states in enumerate(output_states_splits)
                    if split_states.size(0) > 0
                ]
                output_states = torch.cat(output_states_list)
        output_states = MoeOutGradScaler.apply(output_states, self.ep_size)
        # Send expert outputs back along the reverse route (split lists swapped).
        dispatch_states, _ = all_to_all_uneven(output_states, output_split_list, input_split_list, self.ep_group)
        # Invert the argsort permutation to restore the original k-major order.
        recover_experts_idx = torch.empty_like(selected_experts_idx)
        recover_experts_idx[selected_experts_idx] = torch.arange(
            selected_experts_idx.size(0), device=selected_experts_idx.device
        )
        dispatch_states = dispatch_states[recover_experts_idx]
        # One chunk per top-k choice; combine them weighted by the routing weights.
        k_hidden_states = dispatch_states.chunk(self.top_k)
        output_states = k_hidden_states[0] * routing_weights[:, 0, None]
        for i in range(1, self.top_k):
            output_states += k_hidden_states[i] * routing_weights[:, i, None]
        output_states = output_states.reshape(batch_size, sequence_length, hidden_dim)
        return output_states, router_logits