manager.py

import torch


class WeightAsyncStreamManager(object):
    def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
        # Rotating weight slots: [0] weights currently used for compute,
        # [1] previously used weights pending offload back to CPU,
        # [2] weights being prefetched onto the GPU.
        self.active_weights = [None for _ in range(3)]
        # Compute runs on a high-priority stream; H2D prefetch and D2H offload
        # run on lower-priority side streams so copies overlap with compute.
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.cpu_load_stream = torch.cuda.Stream(priority=0)
        self.cuda_load_stream = torch.cuda.Stream(priority=0)
        # Only blocks/phases below these thresholds are offloaded back to CPU.
        self.offload_block_num = offload_ratio * blocks_num
        self.phases_num = phases_num
        self.offload_phases_num = blocks_num * phases_num * offload_ratio

    def prefetch_weights(self, block_idx, blocks_weights):
        # Copy the requested block's weights to the GPU on the prefetch stream;
        # in parallel, move the previously used weights back to CPU (only for
        # blocks inside the offload window).
        with torch.cuda.stream(self.cuda_load_stream):
            self.active_weights[2] = blocks_weights[block_idx]
            self.active_weights[2].to_cuda_async()
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx < self.offload_block_num:
                if self.active_weights[1] is not None:
                    self.active_weights[1].to_cpu_async()

    def swap_weights(self):
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()

        # The freshly prefetched weights become the active set; the old active
        # set becomes the offload candidate for the next prefetch.
        self.active_weights[0], self.active_weights[1] = (
            self.active_weights[2],
            self.active_weights[0],
        )

    def prefetch_phase(self, block_idx, phase_idx, blocks):
        # Phase-level variant of prefetch_weights: load one compute phase of a
        # block to GPU while offloading the previously active phase to CPU.
        with torch.cuda.stream(self.cuda_load_stream):
            new_phase = blocks[block_idx].compute_phases[phase_idx]
            new_phase.to_cuda_async()
            self.active_weights[2] = (phase_idx, new_phase)
        with torch.cuda.stream(self.cpu_load_stream):
            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
                if self.active_weights[1] is not None:
                    _, old_phase = self.active_weights[1]
                    old_phase.to_cpu_async()

    def swap_phases(self):
        self.compute_stream.synchronize()
        self.cpu_load_stream.synchronize()
        self.cuda_load_stream.synchronize()
        self.active_weights[0], self.active_weights[1] = self.active_weights[2], self.active_weights[0]
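

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original file). It assumes the
# caller supplies `blocks` / `blocks_weights` objects exposing the same
# to_cuda_async()/to_cpu_async() interface the manager relies on above:
#
#     manager = WeightAsyncStreamManager(blocks_num=len(blocks), offload_ratio=1)
#     manager.active_weights[0] = blocks_weights[0]
#     manager.active_weights[0].to_cuda_async()
#     for i, block in enumerate(blocks):
#         if i + 1 < len(blocks):
#             manager.prefetch_weights(i + 1, blocks_weights)
#         with torch.cuda.stream(manager.compute_stream):
#             hidden = block(hidden, manager.active_weights[0])
#         manager.swap_weights()
# ---------------------------------------------------------------------------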