benchmark_uploader.py 5.44 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library to upload benchmark generated by BenchmarkLogger to remote repo.

This library requires the google-cloud-bigquery package as a dependency, which
can be installed with:
  > pip install --upgrade google-cloud-bigquery
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json

from google.cloud import bigquery

31
import tensorflow as tf
32
33
34


class BigQueryUploader(object):
  """Upload the benchmark and metric info from JSON input to BigQuery."""

  def __init__(self, gcp_project=None, credentials=None):
    """Initialize BigQueryUploader with proper setting.

    Args:
      gcp_project: string, the name of the GCP project that the log will be
        uploaded to. The default project name will be detected from local
        environment if no value is provided.
      credentials: google.auth.credentials. The credential to access the
        BigQuery service. The default service account credential will be
        detected from local environment if no value is provided. Please use
        google.oauth2.service_account.Credentials to load credential from local
        file for the case that the test is run outside of GCP.
    """
    self._bq_client = bigquery.Client(
        project=gcp_project, credentials=credentials)

  def upload_benchmark_run_json(
      self, dataset_name, table_name, run_id, run_json):
    """Upload benchmark run information to Bigquery.

    The caller's `run_json` is left untouched; the run ID is attached to a
    shallow copy under the "model_id" key and that copy is uploaded.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the data will be uploaded.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format.
      run_json: dict, the JSON data that contains the benchmark run info.
    """
    # Copy first so that tagging the row does not leak into the caller's dict.
    row = dict(run_json)
    row["model_id"] = run_id
    self._upload_json(dataset_name, table_name, [row])

  def upload_benchmark_metric_json(
      self, dataset_name, table_name, run_id, metric_json_list):
    """Upload metric information to Bigquery.

    The caller's metric dicts are left untouched; the run ID is attached to a
    shallow copy of each one under the "run_id" key.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the metric data will be uploaded. This is different from the
        benchmark_run table.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format. This should be the same as the benchmark run_id.
      metric_json_list: list, a list of JSON object that record the metric info.
    """
    rows = []
    for metric in metric_json_list:
      # Copy first so that tagging the row does not leak into the caller's dict.
      row = dict(metric)
      row["run_id"] = run_id
      rows.append(row)
    self._upload_json(dataset_name, table_name, rows)

  def upload_benchmark_run_file(
      self, dataset_name, table_name, run_id, run_json_file):
    """Upload benchmark run information to Bigquery from input json file.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the data will be uploaded.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format.
      run_json_file: string, the file path that contains the run JSON data.
    """
    with tf.gfile.GFile(run_json_file) as f:
      benchmark_json = json.load(f)
    # The file is fully parsed at this point, so it does not need to stay open
    # while the upload request is in flight.
    self.upload_benchmark_run_json(
        dataset_name, table_name, run_id, benchmark_json)

  def upload_metric_file(
      self, dataset_name, table_name, run_id, metric_json_file):
    """Upload metric information to Bigquery from input json file.

    The file is read as one JSON object per line. Lines that contain only
    whitespace are skipped.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the metric data will be uploaded. This is different from the
        benchmark_run table.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format. This should be the same as the benchmark run_id.
      metric_json_file: string, the file path that contains the metric JSON
        data.
    """
    metrics = []
    with tf.gfile.GFile(metric_json_file) as f:
      for line in f:
        record = line.strip()
        # json.loads("") raises ValueError, so a blank or trailing empty line
        # must not be handed to the parser.
        if record:
          metrics.append(json.loads(record))
    self.upload_benchmark_metric_json(
        dataset_name, table_name, run_id, metrics)

  def _upload_json(self, dataset_name, table_name, json_list):
    """Insert the given rows into the BigQuery table and log any failure.

    Errors returned by the insert call are logged via tf.logging.error rather
    than raised.

    Args:
      dataset_name: string, the name of bigquery dataset that owns the table.
      table_name: string, the name of bigquery table that receives the rows.
      json_list: list of dict, the rows to insert. Nothing is sent when the
        list is empty.
    """
    if not json_list:
      # Nothing to upload. NOTE(review): the BigQuery streaming insert API is
      # expected to reject a request that carries no rows, so skip the call.
      return
    # Find the unique table reference based on dataset and table name, so that
    # the data can be inserted to it.
    table_ref = self._bq_client.dataset(dataset_name).table(table_name)
    errors = self._bq_client.insert_rows_json(table_ref, json_list)
    if errors:
      tf.logging.error(
          "Failed to upload benchmark info to bigquery: {}".format(errors))