batch_size.py 5.2 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import logging
from typing import Callable, Tuple

import numpy as np

from deepmd.utils.errors import OutOfMemoryError

class AutoBatchSize:
    """Automatically decide the maximum batch size that will not cause an
    out-of-memory (OOM) error.

    The batch size is measured in *total atoms* (frames * atoms per frame).
    It grows by ``factor`` after a successful call that used the whole
    budget, and shrinks by ``factor`` after a call that raised
    :class:`OutOfMemoryError`.

    Notes
    -----
    We assume all OOM error will raise :class:`OutOfMemoryError`.

    Parameters
    ----------
    initial_batch_size : int, default: 1024
        initial batch size (number of total atoms)
    factor : float, default: 2.
        increased factor

    Attributes
    ----------
    current_batch_size : int
        current batch size (number of total atoms)
    maximum_working_batch_size : int
        maximum working batch size
    minimal_not_working_batch_size : int
        minimal not working batch size
    factor : float
        multiplicative step used to grow / shrink the batch size
    """

    def __init__(self, initial_batch_size: int = 1024, factor: float = 2.) -> None:
        # See also PyTorchLightning/pytorch-lightning#1638
        # TODO: discuss a proper initial batch size
        self.current_batch_size = initial_batch_size
        self.maximum_working_batch_size = 0
        # Large sentinel: no failing batch size has been observed yet.
        self.minimal_not_working_batch_size = 2**31
        self.factor = factor

    def execute(self, callable: Callable, start_index: int, natoms: int) -> Tuple[int, tuple]:
        """Execute a method with the current batch size.

        Parameters
        ----------
        callable : Callable
            The method should accept the batch size (number of frames) and
            start_index as parameters, and return the executed batch size
            (number of frames) and data.
        start_index : int
            start index
        natoms : int
            number of atoms per frame; if it is not positive (e.g. an empty
            system), the current batch size is used directly as the number
            of frames

        Returns
        -------
        int
            executed batch size (number of frames) as reported by the
            callable, 0 if failing to execute
        tuple
            result from callable, None if failing to execute

        Raises
        ------
        OutOfMemoryError
            OOM when batch size is 1
        """
        # Convert the atom budget into a number of frames. Guard against
        # natoms == 0, which would otherwise raise ZeroDivisionError.
        if natoms > 0:
            batch_nframes = self.current_batch_size // natoms
        else:
            batch_nframes = self.current_batch_size
        try:
            # always run at least one frame
            n_batch, result = callable(max(batch_nframes, 1), start_index)
        except OutOfMemoryError as e:
            # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
            # but luckily we only need to catch once
            self.minimal_not_working_batch_size = min(self.minimal_not_working_batch_size, self.current_batch_size)
            # keep the invariant: maximum working < minimal not working
            if self.maximum_working_batch_size >= self.minimal_not_working_batch_size:
                self.maximum_working_batch_size = int(self.minimal_not_working_batch_size / self.factor)
            # the failing budget already corresponds to a single frame,
            # so there is nothing smaller left to try
            if self.minimal_not_working_batch_size <= natoms:
                raise OutOfMemoryError("The callable still throws an out-of-memory (OOM) error even when batch size is 1!") from e
            # adjust the next batch size
            self._adjust_batch_size(1./self.factor)
            return 0, None
        else:
            n_tot = n_batch * natoms
            self.maximum_working_batch_size = max(self.maximum_working_batch_size, n_tot)
            # adjust the next batch size: grow only if this call used the
            # whole budget (one more frame would not have fit) and the grown
            # size stays below the smallest size known to fail
            if n_tot + natoms > self.current_batch_size and self.current_batch_size * self.factor < self.minimal_not_working_batch_size:
                self._adjust_batch_size(self.factor)
            return n_batch, result

    def _adjust_batch_size(self, factor: float) -> None:
        """Scale the current batch size by ``factor`` (truncated to int) and log it."""
        old_batch_size = self.current_batch_size
        self.current_batch_size = int(self.current_batch_size * factor)
        # lazy %-style arguments: the message is only formatted if emitted
        logging.info("Adjust batch size from %d to %d", old_batch_size, self.current_batch_size)

    def execute_all(self, callable: Callable, total_size: int, natoms: int, *args, **kwargs) -> Tuple[np.ndarray]:
        """Execute a method with all given data, batch by batch.

        Parameters
        ----------
        callable : Callable
            The method should accept *args and **kwargs as input and return
            the similar array (or a tuple of arrays) for the given batch.
        total_size : int
            Total size (number of frames)
        natoms : int
            The number of atoms
        *args, **kwargs
            If np.ndarray with more than one dimension, assume the first
            axis is batch and slice it; otherwise pass it through unchanged.

        Returns
        -------
        np.ndarray or tuple of np.ndarray
            Per-batch results concatenated along the first axis. A tuple is
            returned when more than one array is produced per batch; a
            single array otherwise. An empty tuple is returned if nothing
            was executed (``total_size <= 0``).
        """
        def execute_with_batch_size(batch_size: int, start_index: int) -> Tuple[int, Tuple[np.ndarray]]:
            # clip the last batch to the available data
            end_index = min(start_index + batch_size, total_size)

            def take_batch(value):
                # only multi-dimensional arrays are treated as batched data
                if isinstance(value, np.ndarray) and value.ndim > 1:
                    return value[start_index:end_index]
                return value

            return (end_index - start_index), callable(
                *[take_batch(vv) for vv in args],
                **{kk: take_batch(vv) for kk, vv in kwargs.items()},
            )

        index = 0
        results = []
        while index < total_size:
            n_batch, result = self.execute(execute_with_batch_size, index, natoms)
            if not n_batch:
                # OOM: nothing was executed; retry the same index with the
                # reduced batch size set by ``execute``
                continue
            if not isinstance(result, tuple):
                result = (result,)
            index += n_batch
            # Results are kept in the shape returned by the callable. (A
            # former ``rr.reshape((n_batch, -1))`` here discarded its return
            # value and therefore never changed anything.)
            results.append(result)

        r = tuple([np.concatenate(r, axis=0) for r in zip(*results)])
        if len(r) == 1:
            # avoid returning tuple if callable doesn't return tuple
            r = r[0]
        return r