"googlemock/vscode:/vscode.git/clone" did not exist on "1f9c668a0452148f725b242079f70d65d3e93153"
utils.py 816 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from contextlib import contextmanager

import torch


def _noop(*args, **kwargs):
    pass


@contextmanager
def low_resource_init(dtype=torch.half):
    """Skip random weight initialization and switch the default float dtype.

    While the context is active:

    * ``torch.nn.init.kaiming_uniform_``, ``torch.nn.init.uniform_`` and
      ``torch.nn.init.normal_`` are replaced by a stub that leaves the tensor's
      contents untouched and returns the tensor (matching the return contract
      of the real ``torch.nn.init`` functions).
    * ``torch.get_default_dtype()`` is set to *dtype*.

    Only code that looks these functions up on ``torch.nn.init`` at call time
    sees the stubs. References bound before entering the context keep the
    real functions.

    The patches reassign attributes of the ``torch.nn.init`` module and change
    the global default dtype, so they are process-wide for the duration of the
    context. The original functions and the previous default dtype are
    restored on exit, including when the body raises.

    Args:
        dtype: Floating-point dtype to make the default while the context is
            active. Defaults to ``torch.half``, the previously hard-coded value.

    Yields:
        None.
    """

    def _skip_init(*args, **kwargs):
        # The real torch.nn.init functions return the tensor they were given.
        # Mirror that so code like ``w = init.normal_(w)`` does not get None.
        return args[0] if args else kwargs.get("tensor")

    init = torch.nn.init
    patched_names = ("kaiming_uniform_", "uniform_", "normal_")
    # Capture originals before the try so the finally block can always restore.
    saved = {name: getattr(init, name) for name in patched_names}
    old_dtype = torch.get_default_dtype()
    try:
        for name in patched_names:
            setattr(init, name, _skip_init)
        torch.set_default_dtype(dtype)
        yield
    finally:
        for name, original in saved.items():
            setattr(init, name, original)
        torch.set_default_dtype(old_dtype)