loss.py 1.72 KB
Newer Older
zhangqha's avatar
zhangqha committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Tuple

from deepmd.env import tf


class Loss(metaclass=ABCMeta):
    """The abstract class for the loss function.

    Concrete losses must override both :meth:`build` and :meth:`eval`.
    Both are declared with :func:`abc.abstractmethod`, so a subclass that
    leaves either one unimplemented cannot be instantiated.
    """

    @abstractmethod
    def build(self,
              learning_rate: tf.Tensor,
              natoms: tf.Tensor,
              model_dict: Dict[str, tf.Tensor],
              label_dict: Dict[str, tf.Tensor],
              suffix: str) -> Tuple[tf.Tensor, Dict[str, tf.Tensor]]:
        """Build the loss function graph.

        Parameters
        ----------
        learning_rate : tf.Tensor
            learning rate
        natoms : tf.Tensor
            number of atoms
        model_dict : dict[str, tf.Tensor]
            A dictionary that maps model keys to tensors
        label_dict : dict[str, tf.Tensor]
            A dictionary that maps label keys to tensors
        suffix : str
            suffix string supplied by the caller; how it is used is left
            to the concrete implementation

        Returns
        -------
        tf.Tensor
            the total squared loss
        dict[str, tf.Tensor]
            A dictionary that maps loss keys to more loss tensors
        """

    @abstractmethod
    def eval(self,
             sess: tf.Session,
             # ``tf.placeholder`` is a factory function, not a type: the
             # feed keys it produces are ``tf.Tensor`` objects, and the fed
             # values are concrete Python/NumPy data rather than tensors.
             feed_dict: Dict[tf.Tensor, Any],
             natoms: tf.Tensor) -> dict:
        """Evaluate the loss function.

        Parameters
        ----------
        sess : tf.Session
            TensorFlow session
        feed_dict : dict[tf.Tensor, Any]
            A dictionary that maps graph elements (e.g. placeholders) to the
            values fed for them
        natoms : tf.Tensor
            number of atoms

        Returns
        -------
        dict
            A dictionary that maps keys to values. It
            should contain key `natoms`
        """