benchmark_uploader_test.py 3.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for benchmark_uploader."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import tempfile
import unittest
from mock import MagicMock
from mock import patch

import tensorflow as tf  # pylint: disable=g-bad-import-order

try:
  from google.cloud import bigquery
  from official.benchmark import benchmark_uploader
except ImportError:
  bigquery = None
  benchmark_uploader = None


@unittest.skipIf(bigquery is None, 'Bigquery dependency is not installed.')
class BigQueryUploaderTest(tf.test.TestCase):
  """Tests BigQueryUploader against a mocked BigQuery client.

  No real BigQuery connection is made: the client, dataset and table are
  MagicMocks, and each test checks which rows were handed to
  `insert_rows_json` and which dataset/table they were routed to.
  """

  @staticmethod
  def _metric_entries():
    """Returns a fresh list of the metric dicts written to metric.log.

    A new list is built on every call because the uploader may add keys
    (such as 'run_id') to the dicts it receives, and a shared object would
    let such mutations leak from one test into another.
    """
    return [
        {'name': 'accuracy', 'value': 1.0},
        {'name': 'loss', 'value': 0.5},
    ]

  @classmethod
  def _expected_metric_rows(cls):
    """Returns the metric rows expected in BigQuery for run id 'run_id'."""
    return [dict(entry, run_id='run_id') for entry in cls._metric_entries()]

  @patch.object(bigquery, 'Client')
  def setUp(self, mock_bigquery):
    super(BigQueryUploaderTest, self).setUp()
    # bigquery.Client is patched while the uploader is constructed so no real
    # client is created; the fake client is also injected explicitly below.
    self.mock_client = mock_bigquery.return_value
    self.mock_dataset = MagicMock(name="dataset")
    self.mock_table = MagicMock(name="table")
    # Make client.dataset(...).table(...) resolve to the fake table reference.
    self.mock_client.dataset.return_value = self.mock_dataset
    self.mock_dataset.table.return_value = self.mock_table
    # insert_rows_json reports failures as a list of per-row errors, so an
    # empty list models a fully successful insert.
    self.mock_client.insert_rows_json.return_value = []

    self.benchmark_uploader = benchmark_uploader.BigQueryUploader()
    self.benchmark_uploader._bq_client = self.mock_client

    self.log_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    # metric.log holds one JSON object per line; 'w' (not 'a') guarantees the
    # file contains exactly these entries.
    with open(os.path.join(self.log_dir, 'metric.log'), 'w') as f:
      for entry in self._metric_entries():
        json.dump(entry, f)
        f.write("\n")
    with open(os.path.join(self.log_dir, 'run.log'), 'w') as f:
      json.dump({'model_name': 'value'}, f)

  def tearDown(self):
    # Remove the directory returned by get_temp_dir(), which contains log_dir.
    tf.gfile.DeleteRecursively(self.get_temp_dir())
    super(BigQueryUploaderTest, self).tearDown()

  def _assert_rows_inserted(self, expected_rows):
    """Asserts one insert of `expected_rows` into table 'table' of 'dataset'."""
    self.mock_client.dataset.assert_called_with('dataset')
    self.mock_dataset.table.assert_called_with('table')
    self.mock_client.insert_rows_json.assert_called_once_with(
        self.mock_table, expected_rows)

  def test_upload_benchmark_run_json(self):
    self.benchmark_uploader.upload_benchmark_run_json(
        'dataset', 'table', 'run_id', {'model_name': 'value'})

    self._assert_rows_inserted(
        [{'model_name': 'value', 'model_id': 'run_id'}])

  def test_upload_benchmark_metric_json(self):
    self.benchmark_uploader.upload_benchmark_metric_json(
        'dataset', 'table', 'run_id', self._metric_entries())

    self._assert_rows_inserted(self._expected_metric_rows())

  def test_upload_benchmark_run_file(self):
    self.benchmark_uploader.upload_benchmark_run_file(
        'dataset', 'table', 'run_id', os.path.join(self.log_dir, 'run.log'))

    self._assert_rows_inserted(
        [{'model_name': 'value', 'model_id': 'run_id'}])

  def test_upload_metric_file(self):
    self.benchmark_uploader.upload_metric_file(
        'dataset', 'table', 'run_id',
        os.path.join(self.log_dir, 'metric.log'))

    self._assert_rows_inserted(self._expected_metric_rows())


if __name__ == '__main__':
  # Run the tests in this module through TensorFlow's test runner.
  tf.test.main()