# feedforward.py
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

from diffusers.models.attention import FeedForward, GELU, GEGLU
from torch import nn
from xfuser.core.distributed.parallel_state import (
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_rank,
    get_tp_group,
)
import torch
from xfuser.model_executor.layers.base_layer import xFuserLayerBaseWrapper
from xfuser.model_executor.layers.register import xFuserLayerWrappersRegister


@xFuserLayerWrappersRegister.register(FeedForward)
class xFuserFeedForwardWrapper(xFuserLayerBaseWrapper):
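    """Tensor-parallel wrapper for diffusers' FeedForward.

    The first projection (GELU or GEGLU) is sharded column-wise across the
    tensor-parallel ranks and the output projection row-wise, so each rank
    holds 1/tp_degree of the weights; the full result is recovered with a
    single all-reduce in forward().
    """
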
    def __init__(self, feedforward: FeedForward):
        super(xFuserFeedForwardWrapper, self).__init__(module=feedforward)

        tp_degree = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

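        # Column-parallel split: each rank keeps a contiguous 1/tp_degree slice
        # of the first projection's output features, so the activation is
        # computed on a shard of the inner hidden dimension.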
        if isinstance(self.module.net[0], GELU):
            self.module.net[0].proj.weight.data = self.module.net[
                0
            ].proj.weight.data.chunk(tp_degree, dim=0)[tp_rank]
            if self.module.net[0].proj.bias is not None:
                self.module.net[0].proj.bias.data = self.module.net[
                    0
                ].proj.bias.data.chunk(tp_degree, dim=0)[tp_rank]
        elif isinstance(self.module.net[0], GEGLU):
            # GEGLU stacks the hidden and gate projections along dim 0, so each
            # half is sharded across ranks separately and the shards re-stacked
            # to preserve the chunk(2) split inside GEGLU's forward.
            weight_buff = self.module.net[0].proj.weight.data.chunk(2, dim=0)
            a = weight_buff[0].chunk(tp_degree, dim=0)[tp_rank]
            b = weight_buff[1].chunk(tp_degree, dim=0)[tp_rank]
            c = torch.cat([a, b], dim=0)

            self.module.net[0].proj.weight.data = c

            if self.module.net[0].proj.bias is not None:
                bias_buff = self.module.net[0].proj.bias.data.chunk(2, dim=0)
                a = bias_buff[0].chunk(tp_degree, dim=0)[tp_rank]
                b = bias_buff[1].chunk(tp_degree, dim=0)[tp_rank]
                c = torch.cat([a, b], dim=0)
                self.module.net[0].proj.bias.data = c

        else:
            raise TypeError(
                f"activation_fn {type(self.module.net[0])} not supported"
            )

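        # Row-parallel split: the output projection consumes the sharded hidden
        # dimension, so its weight is split along the input-feature axis (dim=1);
        # the partial outputs are summed across ranks in forward().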
        self.module.net[2].weight.data = self.module.net[2].weight.data.chunk(
            tp_degree, dim=1
        )[tp_rank]

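        # The output-projection bias is detached from the sharded linear and
        # re-added after the all-reduce in forward(), so it is applied exactly
        # once rather than being summed tp_degree times.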
        self.has_output_bias = False
        if self.module.net[2].bias is not None:
            self.register_parameter(
                "output_bias", nn.Parameter(self.module.net[2].bias.data.clone())
            )
            self.module.net[2].bias = None
            self.has_output_bias = True

        torch.cuda.empty_cache()

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
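        """Run the sharded FeedForward, all-reduce the partial outputs across
        the tensor-parallel group, then add the output bias exactly once."""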
        hidden_states = self.module(hidden_states, *args, **kwargs)
        get_tp_group().all_reduce(hidden_states)
        if self.has_output_bias:
            hidden_states += self.output_bias
        return hidden_states
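

# Example usage (a sketch, assuming the xfuser tensor-parallel process group
# has already been initialized on each rank by xfuser's runtime setup):
#
#     from diffusers.models.attention import FeedForward
#     ff = FeedForward(dim=1024, activation_fn="geglu")
#     tp_ff = xFuserFeedForwardWrapper(ff)   # shards ff in place for this rank
#     out = tp_ff(hidden_states)             # partial matmul + all-reduce
#
# In practice the @xFuserLayerWrappersRegister.register(FeedForward) decorator
# lets xfuser's layer-wrapping machinery apply this wrapper automatically to
# FeedForward modules it encounters.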