# base_module.py -- commit "init" by luopl
from torch import nn

from ..utils import DistriConfig


class BaseModule(nn.Module):
    """Base wrapper that pairs an ``nn.Module`` with a ``DistriConfig``.

    The constructor stores the wrapped module and its configuration and
    initialises a few bookkeeping attributes. ``forward`` is deliberately
    left unimplemented, so concrete subclasses are expected to override it.

    Attributes:
        module: The wrapped module handed to the constructor.
        distri_config: The configuration object handed to the constructor.
        comm_manager: ``None`` until ``set_comm_manager`` is called.
        counter: Starts at 0; overwritten by ``set_counter``.
        buffer_list: Starts as ``None``; never assigned by this class itself
            (presumably filled in by subclasses -- confirm against them).
        idx: Starts as ``None``; never assigned by this class itself
            (presumably filled in by subclasses -- confirm against them).
    """

    def __init__(
        self,
        module: nn.Module,
        distri_config: DistriConfig,
    ) -> None:
        """Store ``module`` and ``distri_config`` and reset internal state."""
        # nn.Module must be initialised before any submodule is attached.
        super().__init__()

        # Constructor arguments.
        self.module = module
        self.distri_config = distri_config

        # State that is populated after construction.
        self.comm_manager = None
        self.counter = 0
        self.buffer_list = self.idx = None

    def forward(self, *args, **kwargs):
        """Always raise ``NotImplementedError``; subclasses supply the logic."""
        raise NotImplementedError

    def set_counter(self, counter: int = 0) -> None:
        """Overwrite ``self.counter`` with ``counter`` (0 when omitted)."""
        self.counter = counter

    def set_comm_manager(self, comm_manager) -> None:
        """Attach ``comm_manager`` as ``self.comm_manager``."""
        self.comm_manager = comm_manager