parallel.py 723 Bytes
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

from fastvideo.utils.communications import all_gather
from fastvideo.utils.parallel_states import nccl_info


def parallel_forward(fn_):

    def wrapTheFunction(_, hidden_states, *args, **kwargs):
        if kwargs['parallel']:
            hidden_states = torch.chunk(hidden_states, nccl_info.sp_size, dim=-2)[nccl_info.rank_within_group]
            kwargs['attn_mask'] = torch.chunk(kwargs['attn_mask'], nccl_info.sp_size,
                                              dim=-2)[nccl_info.rank_within_group]
        output = fn_(_, hidden_states, *args, **kwargs)

        if kwargs['parallel']:
            output = all_gather(output.contiguous(), dim=-2)

        return output

    return wrapTheFunction