manager.py 1.51 KB
Newer Older
gushiqiao's avatar
gushiqiao committed
1
2
3
import torch


4
class WeightAsyncStreamManager:
    """Two-slot manager that prefetches weights on a side CUDA stream.

    ``active_weights`` is a two-element list:

    * ``active_weights[0]`` holds the entry swapped in by the most recent
      ``swap_weights()`` / ``swap_phases()`` call.
    * ``active_weights[1]`` is the prefetch slot written by
      ``prefetch_weights()`` / ``prefetch_phase()``; whatever it held before is
      offloaded with ``to_cpu_async()`` first.

    Transfers are issued inside ``torch.cuda.stream(self.load_stream)``. The
    swap methods block the host on both ``compute_stream`` and ``load_stream``
    before exchanging the slots.

    NOTE(review): this presumes the weight objects' ``to_cuda_async()`` /
    ``to_cpu_async()`` enqueue their copies on the current stream, so that the
    ``load_stream.synchronize()`` in a swap covers them -- confirm in the
    weight classes.
    """

    def __init__(self):
        # Slot 0: entry exposed after a swap; slot 1: prefetch target.
        self.active_weights = [None, None]
        # In PyTorch a lower priority number means higher priority, so work on
        # the compute stream is favoured over the background transfer stream.
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.load_stream = torch.cuda.Stream(priority=0)

    def prefetch_weights(self, block_idx, blocks_weights):
        """Begin uploading ``blocks_weights[block_idx]`` on the load stream.

        The current occupant of the prefetch slot (``active_weights[1]``), if
        any, is offloaded with ``to_cpu_async()``; the requested block is then
        uploaded with ``to_cuda_async()`` and stored in that slot. Call
        :meth:`swap_weights` to move it into slot 0.

        Args:
            block_idx: Index of the block to prefetch.
            blocks_weights: Indexable collection of weight objects exposing
                ``to_cuda_async()`` and ``to_cpu_async()``.
        """
        with torch.cuda.stream(self.load_stream):
            previous = self.active_weights[1]
            if previous is not None:
                previous.to_cpu_async()
            incoming = blocks_weights[block_idx]
            incoming.to_cuda_async()
            self.active_weights[1] = incoming

    def swap_weights(self):
        """Wait for both streams on the host, then exchange the two slots."""
        self._sync_and_swap()

    def prefetch_phase(self, block_idx, phase_idx, blocks):
        """Begin uploading one compute phase of a block on the load stream.

        The phase ``blocks[block_idx].compute_phases[phase_idx]`` is uploaded
        with ``to_cuda_async()`` and stored in the prefetch slot as the pair
        ``(phase_idx, phase)``. The phase previously held in that slot, if any,
        is offloaded with ``to_cpu_async()`` first.

        The offload step unpacks slot 1 as a ``(phase_idx, phase)`` pair, so it
        expects that slot to have been filled by ``prefetch_phase`` rather than
        by ``prefetch_weights``; do not mix the two modes on one manager
        without resetting ``active_weights``.

        Args:
            block_idx: Index of the block whose phase is prefetched.
            phase_idx: Index into that block's ``compute_phases``.
            blocks: Indexable collection of blocks exposing ``compute_phases``.
        """
        with torch.cuda.stream(self.load_stream):
            previous = self.active_weights[1]
            if previous is not None:
                _, old_phase = previous
                old_phase.to_cpu_async()
            new_phase = blocks[block_idx].compute_phases[phase_idx]
            new_phase.to_cuda_async()
            self.active_weights[1] = (phase_idx, new_phase)

    def swap_phases(self):
        """Wait for both streams on the host, then exchange the two slots."""
        self._sync_and_swap()

    def _sync_and_swap(self):
        """Block the host on the compute then load stream; swap slots in place."""
        self.compute_stream.synchronize()
        self.load_stream.synchronize()
        slots = self.active_weights
        # In-place element swap keeps the same list object for callers that
        # hold a reference to ``active_weights``.
        slots[0], slots[1] = slots[1], slots[0]