# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for benchmark_uploader."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import tempfile
import unittest
from mock import MagicMock
from mock import patch

import tensorflow as tf  # pylint: disable=g-bad-import-order

try:
  from google.cloud import bigquery
  from official.benchmark import benchmark_uploader
except ImportError:
  bigquery = None
  benchmark_uploader = None


@unittest.skipIf(bigquery is None, "Bigquery dependency is not installed.")
class BigQueryUploaderTest(tf.test.TestCase):
  """Tests BigQueryUploader against a mocked BigQuery client.

  No real BigQuery service is contacted: every test inspects the calls that
  the uploader makes on a mock client installed in `setUp`.
  """

  @patch.object(bigquery, "Client")
  def setUp(self, mock_bigquery):
    """Creates an uploader backed by a mock client and writes fixture logs.

    Args:
      mock_bigquery: Mock injected by `patch.object` in place of
        `bigquery.Client` while this method runs.
    """
    # Chain to the base class so its own per-test setup still happens.
    super(BigQueryUploaderTest, self).setUp()

    self.mock_client = mock_bigquery.return_value
    self.mock_dataset = MagicMock(name="dataset")
    self.mock_table = MagicMock(name="table")
    self.mock_client.dataset.return_value = self.mock_dataset
    self.mock_dataset.table.return_value = self.mock_table
    # An empty list is returned from the mocked insert; presumably this is
    # how the client reports "no row errors" -- confirm against the
    # google-cloud-bigquery documentation.
    self.mock_client.insert_rows_json.return_value = []

    self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
    # The patch above is only active during setUp, so pin the mock client on
    # the uploader explicitly for use inside the test methods.
    self.benchmark_uploader._bq_client = self.mock_client

    self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())

    # Metric log fixture: one JSON object per line. The directory is brand
    # new, so "w" is used (consistent with run.log below).
    with open(os.path.join(self.log_dir, "metric.log"), "w") as f:
      json.dump({"name": "accuracy", "value": 1.0}, f)
      f.write("\n")
      json.dump({"name": "loss", "value": 0.5}, f)
      f.write("\n")

    # Run log fixture: a single JSON object.
    with open(os.path.join(self.log_dir, "run.log"), "w") as f:
      json.dump({"model_name": "value"}, f)

  def tearDown(self):
    """Removes the temporary directory tree used by the test."""
    tf.io.gfile.rmtree(self.get_temp_dir())
    super(BigQueryUploaderTest, self).tearDown()

  def test_upload_benchmark_run_json(self):
    """Run info dict is inserted as one row with the run id as `model_id`."""
    self.benchmark_uploader.upload_benchmark_run_json(
        "dataset", "table", "run_id", {"model_name": "value"})

    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, [{"model_name": "value", "model_id": "run_id"}])

  def test_upload_benchmark_metric_json(self):
    """Each metric dict is inserted with a `run_id` field added."""
    metric_json_list = [
        {"name": "accuracy", "value": 1.0},
        {"name": "loss", "value": 0.5}
    ]
    expected_params = [
        {"run_id": "run_id", "name": "accuracy", "value": 1.0},
        {"run_id": "run_id", "name": "loss", "value": 0.5}
    ]
    self.benchmark_uploader.upload_benchmark_metric_json(
        "dataset", "table", "run_id", metric_json_list)

    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, expected_params)

  def test_upload_benchmark_run_file(self):
    """Run info read from `run.log` is inserted like the in-memory dict."""
    self.benchmark_uploader.upload_benchmark_run_file(
        "dataset", "table", "run_id", os.path.join(self.log_dir, "run.log"))

    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, [{"model_name": "value", "model_id": "run_id"}])

  def test_upload_metric_file(self):
    """Metrics read line by line from `metric.log` are inserted in one call."""
    self.benchmark_uploader.upload_metric_file(
        "dataset", "table", "run_id",
        os.path.join(self.log_dir, "metric.log"))
    expected_params = [
        {"run_id": "run_id", "name": "accuracy", "value": 1.0},
        {"run_id": "run_id", "name": "loss", "value": 0.5}
    ]
    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, expected_params)

  def test_insert_run_status(self):
    """Inserting a run status issues the expected INSERT query."""
    self.benchmark_uploader.insert_run_status(
        "dataset", "table", "run_id", "status")
    expected_query = ("INSERT dataset.table "
                      "(run_id, status) "
                      "VALUES('run_id', 'status')")
    self.mock_client.query.assert_called_once_with(query=expected_query)

  def test_update_run_status(self):
    """Updating a run status issues the expected UPDATE query."""
    self.benchmark_uploader.update_run_status(
        "dataset", "table", "run_id", "status")
    expected_query = ("UPDATE dataset.table "
                      "SET status = 'status' "
                      "WHERE run_id = 'run_id'")
    self.mock_client.query.assert_called_once_with(query=expected_query)


# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == "__main__":
  tf.test.main()