# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Library to upload benchmark generated by BenchmarkLogger to remote repo.

This library requires the google cloud bigquery library as a dependency, which
can be
installed with:
  > pip install --upgrade google-cloud-bigquery
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import logging

from google.cloud import bigquery
from google.cloud import exceptions

import tensorflow as tf



class BigQueryUploader(object):
  """Upload the benchmark and metric info from JSON input to BigQuery. """

  def __init__(self, gcp_project=None, credentials=None):
    """Initialize BigQueryUploader with proper setting.

    Args:
      gcp_project: string, the name of the GCP project that the log will be
        uploaded to. The default project name will be detected from local
        environment if no value is provided.
      credentials: google.auth.credentials. The credential to access the
        BigQuery service. The default service account credential will be
        detected from local environment if no value is provided. Please use
        google.oauth2.service_account.Credentials to load credential from local
        file for the case that the test is run outside of GCP.
    """
    self._bq_client = bigquery.Client(
        project=gcp_project, credentials=credentials)

  def upload_benchmark_run_json(
      self, dataset_name, table_name, run_id, run_json):
    """Upload benchmark run information to Bigquery.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the data will be uploaded.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format.
      run_json: dict, the JSON data that contains the benchmark run info.
        Note: this dict is mutated in place — the run ID is attached to it
        under the "model_id" key before upload.
    """
    run_json["model_id"] = run_id
    self._upload_json(dataset_name, table_name, [run_json])

  def upload_benchmark_metric_json(
      self, dataset_name, table_name, run_id, metric_json_list):
    """Upload metric information to Bigquery.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the metric data will be uploaded. This is different from the
        benchmark_run table.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format. This should be the same as the benchmark run_id.
      metric_json_list: list, a list of JSON object that record the metric info.
        Note: each dict in the list is mutated in place — the run ID is
        attached under the "run_id" key before upload.
    """
    for m in metric_json_list:
      m["run_id"] = run_id
    self._upload_json(dataset_name, table_name, metric_json_list)

  def upload_benchmark_run_file(
      self, dataset_name, table_name, run_id, run_json_file):
    """Upload benchmark run information to Bigquery from input json file.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the data will be uploaded.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format.
      run_json_file: string, the file path that contains the run JSON data.
        The file holds a single JSON object.
    """
    with tf.io.gfile.GFile(run_json_file) as f:
      benchmark_json = json.load(f)
      self.upload_benchmark_run_json(
          dataset_name, table_name, run_id, benchmark_json)

  def upload_metric_file(
      self, dataset_name, table_name, run_id, metric_json_file):
    """Upload metric information to Bigquery from input json file.

    Args:
      dataset_name: string, the name of bigquery dataset where the data will be
        uploaded.
      table_name: string, the name of bigquery table under the dataset where
        the metric data will be uploaded. This is different from the
        benchmark_run table.
      run_id: string, a unique ID that will be attached to the data, usually
        this is a UUID4 format. This should be the same as the benchmark run_id.
      metric_json_file: string, the file path that contains the metric JSON
        data. The file is newline-delimited JSON (one JSON object per line).
    """
    with tf.io.gfile.GFile(metric_json_file) as f:
      metrics = []
      for line in f:
        metrics.append(json.loads(line.strip()))
      self.upload_benchmark_metric_json(
          dataset_name, table_name, run_id, metrics)

  def _upload_json(self, dataset_name, table_name, json_list):
    """Insert a list of JSON rows into the given dataset.table via streaming."""
    # Find the unique table reference based on dataset and table name, so that
    # the data can be inserted to it.
    table_ref = self._bq_client.dataset(dataset_name).table(table_name)
    errors = self._bq_client.insert_rows_json(table_ref, json_list)
    if errors:
      # Use the stdlib logger: tf.logging is a TF 1.x API that was removed in
      # TF 2.x, while this file already uses the TF 2.x tf.io.gfile API.
      logging.error(
          "Failed to upload benchmark info to bigquery: %s", errors)

  def _run_status_query(self, query, run_id, run_status):
    """Run a DML query with run_id/status bound as named query parameters.

    Binding the values (rather than string-formatting them into the SQL)
    prevents SQL injection and quoting bugs. Errors are logged, not raised,
    to keep status updates best-effort.

    Args:
      query: string, DML statement referencing @run_id and @status.
      run_id: string, value bound to @run_id.
      run_status: string, value bound to @status.

    Returns:
      None on success, or the caught GoogleCloudError on failure.
    """
    job_config = bigquery.QueryJobConfig()
    job_config.query_parameters = [
        bigquery.ScalarQueryParameter("run_id", "STRING", run_id),
        bigquery.ScalarQueryParameter("status", "STRING", run_status),
    ]
    try:
      self._bq_client.query(query=query, job_config=job_config).result()
    except exceptions.GoogleCloudError as e:
      return e
    return None

  def insert_run_status(self, dataset_name, table_name, run_id, run_status):
    """Insert the run status into the Bigquery run status table."""
    # Table identifiers cannot be bound as query parameters; they come from
    # trusted config. The values are bound via @run_id/@status (see
    # _run_status_query) to avoid SQL injection.
    query = ("INSERT {ds}.{tb} "
             "(run_id, status) "
             "VALUES(@run_id, @status)").format(
                 ds=dataset_name, tb=table_name)
    error = self._run_status_query(query, run_id, run_status)
    if error:
      logging.error("Failed to insert run status: %s", error)

  def update_run_status(self, dataset_name, table_name, run_id, run_status):
    """Update the run status in the Bigquery run status table."""
    query = ("UPDATE {ds}.{tb} "
             "SET status = @status "
             "WHERE run_id = @run_id").format(
                 ds=dataset_name, tb=table_name)
    error = self._run_status_query(query, run_id, run_status)
    if error:
      logging.error("Failed to update run status: %s", error)