parallel_op.py 2.67 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import itertools
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple

from deepmd.env import tf
from deepmd.utils.sess import run_sess


class ParallelOp:
    """Run an op with data parallelism.

    ``builder`` is called ``nthreads`` times, each call under its own
    ``task_<i>`` name scope, so the graph holds ``nthreads`` independent
    copies of the op. :meth:`generate` then packs up to ``nthreads`` feed
    dicts into a single session run, one feed dict per copy.

    Parameters
    ----------
    builder : Callable[..., Tuple[Dict[str, tf.Tensor], Tuple[tf.Tensor]]]
        returns two objects: a dict which stores placeholders by key, and a tuple with the final op(s)
    nthreads : int, optional
        the number of threads, i.e. copies of the op; must be at least 1.
        If omitted, ``config.inter_op_parallelism_threads`` is used (clamped
        to at least 1) when ``config`` is given, otherwise 1.
    config : tf.ConfigProto, optional
        tf.ConfigProto, only consulted to choose a default for ``nthreads``

    Raises
    ------
    ValueError
        if ``nthreads`` is given and is smaller than 1

    Examples
    --------
    >>> from deepmd.env import tf
    >>> from deepmd.utils.parallel_op import ParallelOp
    >>> def builder():
    ...     x = tf.placeholder(tf.int32, [1])
    ...     return {"x": x}, (x + 1)
    ...
    >>> p = ParallelOp(builder, nthreads=4)
    >>> def feed():
    ...     for ii in range(10):
    ...         yield {"x": [ii]}
    ...
    >>> print(*p.generate(tf.Session(), feed()))
    [1] [2] [3] [4] [5] [6] [7] [8] [9] [10]
    """

    def __init__(
        self,
        builder: Callable[..., Tuple[Dict[str, tf.Tensor], Tuple[tf.Tensor]]],
        nthreads: Optional[int] = None,
        config: Optional[tf.ConfigProto] = None,
    ) -> None:
        if nthreads is not None:
            if nthreads < 1:
                # With zero copies, generate() would never pull from the feed
                # and would loop forever; reject the value up front instead.
                raise ValueError("nthreads must be at least 1, got %r" % (nthreads,))
            self.nthreads = nthreads
        elif config is not None:
            # A value of 0 in the config is clamped to a single copy.
            self.nthreads = max(config.inter_op_parallelism_threads, 1)
        else:
            self.nthreads = 1

        # placeholders[i] / ops[i] belong to the i-th copy of the graph.
        self.placeholders = []
        self.ops = []
        for ii in range(self.nthreads):
            with tf.name_scope("task_%d" % ii):
                placeholder, op = builder()
            self.placeholders.append(placeholder)
            self.ops.append(op)

    def generate(
        self,
        sess: tf.Session,
        feed: Iterable[Dict[str, Any]],
    ) -> Generator[Tuple, None, None]:
        """Run the op copies over a stream of feed dicts.

        Feed dicts are consumed in batches of up to ``nthreads``. The i-th
        dict of a batch is mapped onto the placeholders of the i-th op copy,
        and the whole batch is evaluated with a single ``run_sess`` call.
        The last batch may be smaller than ``nthreads``.

        Parameters
        ----------
        sess : tf.Session
            session used to run the ops
        feed : Iterable[Dict[str, Any]]
            iterable (e.g. a generator) which yields feed_dicts keyed by the
            placeholder keys returned from ``builder``

        Yields
        ------
        Any
            session returns, one per item of ``feed``, in input order
        """
        # iter() lets plain iterables (lists, tuples) work as well as generators.
        feed_iter = iter(feed)
        while True:
            batch = list(itertools.islice(feed_iter, self.nthreads))
            if not batch:
                return
            feed_dict = {}
            for placeholders, fd in zip(self.placeholders, batch):
                for kk, vv in fd.items():
                    feed_dict[placeholders[kk]] = vv
            # Only run as many copies as there are feed dicts in this batch.
            yield from run_sess(sess, self.ops[: len(batch)], feed_dict=feed_dict)