# parallel.py
import functools

import torch
import torch.distributed as dist
import xfuser


def initialize_parall_group(ring_degree, ulysses_degree):
    """Initialize torch.distributed and xFuser's sequence-parallel groups."""
    dist.init_process_group("nccl")
    xfuser.core.distributed.init_distributed_environment(
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
    )

    # The overall sequence-parallel degree is the product of the ring and
    # Ulysses degrees.
    xfuser.core.distributed.initialize_model_parallel(
        sequence_parallel_degree=ring_degree * ulysses_degree,
        ring_degree=ring_degree,
        ulysses_degree=ulysses_degree,
    )
    # Assumes one process per GPU on a single node, so the global rank is
    # also the local device index.
    torch.cuda.set_device(dist.get_rank())

# Thin wrappers around xFuser's distributed state accessors.
def get_parallel_group():
    return xfuser.core.distributed.get_world_group()

def get_sequence_parallel_world_size():
    return xfuser.core.distributed.parallel_state.get_sequence_parallel_world_size()

def get_sequence_parallel_rank():
    return xfuser.core.distributed.parallel_state.get_sequence_parallel_rank()

def get_sp_group():
    return xfuser.core.distributed.parallel_state.get_sp_group()


def parallel_forward(fn_):
    """Shard the sequence dimension across sequence-parallel ranks around `fn_`."""

    @functools.wraps(fn_)
    def wrapped_forward(self, hidden_states, *args, **kwargs):
        parallel = kwargs.get('parallel', False)
        if parallel:
            # Each rank keeps only its shard of the sequence dimension (dim=-2).
            sp_size = get_sequence_parallel_world_size()
            sp_rank = get_sequence_parallel_rank()
            hidden_states = torch.chunk(hidden_states, sp_size, dim=-2)[sp_rank]
            if kwargs.get('attn_mask') is not None:
                kwargs['attn_mask'] = torch.chunk(kwargs['attn_mask'], sp_size, dim=-2)[sp_rank]

        output = fn_(self, hidden_states, *args, **kwargs)

        if parallel:
            # Re-assemble the full sequence from every rank's shard.
            output = get_sp_group().all_gather(output.contiguous(), dim=-2)

        return output

    return wrapped_forward
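

# --- Illustrative usage sketch ---
# A minimal, hypothetical example of how `parallel_forward` is meant to be
# used: it decorates a module's forward so that, when `parallel=True`, the
# sequence dimension is sharded before the call and gathered back afterwards.
# `ToyBlock`, the tensor shapes, and the launch command are assumptions for
# illustration only, not part of any real pipeline. Run under a distributed
# launcher, e.g. `torchrun --nproc_per_node=2 parallel.py`.
if __name__ == "__main__":
    import os

    class ToyBlock(torch.nn.Module):
        @parallel_forward
        def forward(self, hidden_states, attn_mask=None, parallel=False):
            # Identity stand-in for a real attention/transformer block.
            return hidden_states

    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    initialize_parall_group(ring_degree=1, ulysses_degree=world_size)

    block = ToyBlock().cuda()
    x = torch.randn(2, 16, 8, device="cuda")      # [batch, seq, dim]
    mask = torch.ones(2, 16, 16, device="cuda")   # chunked along dim=-2
    y = block(x, attn_mask=mask, parallel=True)
    assert y.shape == x.shape                     # full sequence restored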