train_higgs_test.py 5.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the boosted trees HIGGS training script (train_higgs)."""
16
17
18
19
20
21
22
23
24
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tempfile

import numpy as np
import pandas as pd
25
import tensorflow as tf
26

27
# pylint: disable=g-bad-import-order
28
from official.boosted_trees import train_higgs
29
from official.utils.testing import integration
30

31
# Small CSV fixture that is expected to sit next to this test module.
TEST_CSV = os.path.join(os.path.dirname(__file__), "train_higgs_test.csv")

# Only surface errors so that test output is not flooded with training logs.
tf.logging.set_verbosity(tf.logging.ERROR)
34
35
36
37
38
39
40
41
42
43
44
45
46
47


class BaseTest(tf.test.TestCase):
  """Tests for the boosted trees HIGGS training script (train_higgs)."""

  @classmethod
  def setUpClass(cls):  # pylint: disable=invalid-name
    """Registers the train_higgs command-line flags once for the class."""
    super(BaseTest, cls).setUpClass()
    train_higgs.define_train_higgs_flags()

  def setUp(self):
    """Converts the CSV fixture into the .npz file that train_higgs reads."""
    super(BaseTest, self).setUp()
    # Create temporary CSV file
    self.data_dir = self.get_temp_dir()
    # `.values` replaces the deprecated (and since removed) `as_matrix()`.
    data = pd.read_csv(
        TEST_CSV, dtype=np.float32, names=["c%02d" % i for i in range(29)]
    ).values
    self.input_npz = os.path.join(self.data_dir, train_higgs.NPZ_FILE)
    # numpy.savez doesn't take gfile.Gfile, so need to write down and copy.
    with tempfile.NamedTemporaryFile() as tmpfile:
      np.savez_compressed(tmpfile, data=data)
      # Push buffered bytes to disk before the file is re-read by name.
      tmpfile.flush()
      tf.gfile.Copy(tmpfile.name, self.input_npz)

  def test_read_higgs_data(self):
    """Tests read_higgs_data() function."""
    # Error when a wrong data_dir is given.
    with self.assertRaisesRegexp(RuntimeError, "Error loading data.*"):
      train_higgs.read_higgs_data(
          self.data_dir + "non-existing-path",
          train_start=0, train_count=15, eval_start=15, eval_count=5)

    # Loading fine with the correct data_dir.
    train_data, eval_data = train_higgs.read_higgs_data(
        self.data_dir,
        train_start=0, train_count=15, eval_start=15, eval_count=5)
    self.assertEqual((15, 29), train_data.shape)
    self.assertEqual((5, 29), eval_data.shape)

  def test_make_inputs_from_np_arrays(self):
    """Tests make_inputs_from_np_arrays() function."""
    train_data, _ = train_higgs.read_higgs_data(
        self.data_dir,
        train_start=0, train_count=15, eval_start=15, eval_count=5)
    # Column 0 is passed as the label; the remaining 28 columns as features.
    input_fn, feature_columns = train_higgs.make_inputs_from_np_arrays(
        features_np=train_data[:, 1:], label_np=train_data[:, 0:1])

    # Check feature columns.
    self.assertEqual(28, len(feature_columns))
    bucketized_column_type = type(
        tf.feature_column.bucketized_column(
            tf.feature_column.numeric_column("feature_01"),
            boundaries=[0, 1, 2]))  # dummy boundaries.
    for feature_column in feature_columns:
      self.assertIsInstance(feature_column, bucketized_column_type)
      # At least 2 boundaries.
      self.assertGreaterEqual(len(feature_column.boundaries), 2)
    feature_names = ["feature_%02d" % i for i in range(1, 29)]
    # Tests that the source column names of the bucketized columns match.
    self.assertAllEqual(feature_names,
                        [col.source_column.name for col in feature_columns])

    # Check features.
    features, labels = input_fn().make_one_shot_iterator().get_next()
    with tf.Session() as sess:
      features, labels = sess.run((features, labels))
    self.assertIsInstance(features, dict)
    self.assertAllEqual(feature_names, sorted(features.keys()))
    self.assertAllEqual([[15, 1]] * 28,
                        [features[name].shape for name in feature_names])
    # Validate actual values of some features.
    self.assertAllClose(
        [0.869293, 0.907542, 0.798834, 1.344384, 1.105009, 1.595839,
         0.409391, 0.933895, 1.405143, 1.176565, 0.945974, 0.739356,
         1.384097, 1.383548, 1.343652],
        np.squeeze(features[feature_names[0]], 1))
    self.assertAllClose(
        [-0.653674, -0.213641, 1.540659, -0.676015, 1.020974, 0.643109,
         -1.038338, -2.653732, 0.567342, 0.534315, 0.720819, -0.481741,
         1.409523, -0.307865, 1.474605],
        np.squeeze(features[feature_names[10]], 1))

  def test_end_to_end(self):
    """Tests end-to-end running."""
    model_dir = os.path.join(self.get_temp_dir(), "model")
    integration.run_synthetic(
        main=train_higgs.main, tmp_root=self.get_temp_dir(), extra_flags=[
            "--data_dir", self.data_dir,
            "--model_dir", model_dir,
            "--n_trees", "5",
            "--train_start", "0",
            "--train_count", "12",
            "--eval_start", "12",
            "--eval_count", "8",
        ],
        synth=False, max_train=None)
    self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))

  def test_end_to_end_with_export(self):
    """Tests end-to-end running with --export_dir set."""
    model_dir = os.path.join(self.get_temp_dir(), "model")
    export_dir = os.path.join(self.get_temp_dir(), "export")
    integration.run_synthetic(
        main=train_higgs.main, tmp_root=self.get_temp_dir(), extra_flags=[
            "--data_dir", self.data_dir,
            "--model_dir", model_dir,
            "--export_dir", export_dir,
            "--n_trees", "5",
            "--train_start", "0",
            "--train_count", "12",
            "--eval_start", "12",
            "--eval_count", "8",
        ],
        synth=False, max_train=None)
    self.assertTrue(tf.gfile.Exists(os.path.join(model_dir, "checkpoint")))
    self.assertTrue(tf.gfile.Exists(export_dir))


150
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()