# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Binary to upload benchmark generated by BenchmarkLogger to remote repo. This library require google cloud bigquery lib as dependency, which can be installed with: > pip install --upgrade google-cloud-bigquery """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys import uuid from absl import app as absl_app from absl import flags from official.benchmark import benchmark_uploader from official.utils.flags import core as flags_core from official.utils.logs import logger def main(_): if not flags.FLAGS.benchmark_log_dir: print("Usage: benchmark_uploader.py --benchmark_log_dir=/some/dir") sys.exit(1) uploader = benchmark_uploader.BigQueryUploader( gcp_project=flags.FLAGS.gcp_project) run_id = str(uuid.uuid4()) run_json_file = os.path.join( flags.FLAGS.benchmark_log_dir, logger.BENCHMARK_RUN_LOG_FILE_NAME) metric_json_file = os.path.join( flags.FLAGS.benchmark_log_dir, logger.METRIC_LOG_FILE_NAME) uploader.upload_benchmark_run_file( flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_table, run_id, run_json_file) uploader.upload_metric_file( flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_metric_table, run_id, metric_json_file) # Assume the run finished successfully before user invoke the upload script. uploader.insert_run_status( flags.FLAGS.bigquery_data_set, flags.FLAGS.bigquery_run_status_table, run_id, logger.RUN_STATUS_SUCCESS) if __name__ == "__main__": flags_core.define_benchmark() flags.adopt_module_key_flags(flags_core) absl_app.run(main=main)