manager.py 853 Bytes
Newer Older
gushiqiao's avatar
gushiqiao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch


class WeightStreamManager:
    """Two-slot weight buffer that stages block weights on a side CUDA stream.

    Two CUDA streams are created:

    * ``compute_stream`` with priority -1;
    * ``load_stream`` with priority 0, used by ``prefetch_weights``.

    In PyTorch a lower priority number means a higher scheduling priority.
    ``compute_stream`` is not used inside this class; presumably callers
    enqueue their computation on it (confirm at call sites).

    ``active_weights`` holds two slots:

    * ``prefetch_weights`` always fills slot 1;
    * ``swap_weights`` exchanges slot 0 and slot 1.

    NOTE(review): slot 0 appears to be the set the caller consumes after
    a swap. Verify this against the callers.
    """

    def __init__(self):
        # Two weight slots, both empty until the first prefetch.
        self.active_weights = [None, None]
        self.compute_stream = torch.cuda.Stream(priority=-1)
        self.load_stream = torch.cuda.Stream(priority=0)

    def prefetch_weights(self, block_idx, blocks_weights):
        """Stage ``blocks_weights[block_idx]`` into slot 1 on ``load_stream``.

        Whatever currently occupies slot 1 (if anything) first gets its
        ``to_cpu_sync()`` called. The requested entry then gets
        ``to_cuda_sync()`` and is stored in slot 1. Both helpers are
        project-defined; from their names they presumably move data between
        host and device (not visible here).

        Args:
            block_idx: Index into ``blocks_weights``.
            blocks_weights: Indexable collection of weight objects that expose
                ``to_cpu_sync()`` / ``to_cuda_sync()``.
        """
        with torch.cuda.stream(self.load_stream):
            outgoing = self.active_weights[1]
            if outgoing is not None:
                outgoing.to_cpu_sync()

            incoming = blocks_weights[block_idx]
            incoming.to_cuda_sync()
            self.active_weights[1] = incoming

    def swap_weights(self):
        """Wait for both streams to drain, then exchange slots 0 and 1.

        The host blocks on ``compute_stream`` first and ``load_stream``
        second. Only then are the slots exchanged, so neither stream still
        has queued work when the slot contents change. The swap mutates the
        existing list in place rather than rebinding ``active_weights``.
        """
        for stream in (self.compute_stream, self.load_stream):
            stream.synchronize()

        current, staged = self.active_weights[0], self.active_weights[1]
        self.active_weights[0] = staged
        self.active_weights[1] = current