#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import argparse
import os

import tensorrt as trt
from loguru import logger

from .common_runtime import *

try:
    # Sometimes python does not understand FileNotFoundError
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError


def GiB(val):
    """
    Converts a size expressed in GiB to bytes.

    Args:
        val (int or float): The number of gibibytes.

    Returns:
        int or float: ``val`` multiplied by 2**30.
    """
    # Parenthesized on purpose: `val * 1 << 30` parses as `(val * 1) << 30`, which
    # only works for integers and raises TypeError for fractional sizes such as 0.5.
    return val * (1 << 30)


def add_help(description):
    """
    Builds a minimal argument parser so that a sample responds to -h/--help.

    Args:
        description (str): Description of the sample, shown in the help text.
    """
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # The parsed result is intentionally discarded; parsing is what makes argparse
    # handle -h/--help, while any other arguments are tolerated as unknown.
    parser.parse_known_args()


def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""):
    """
    Parses sample arguments and locates the requested data files.

    Args:
        description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.
            The default list is only read, never mutated.
        err_msg (str): Extra text appended to the error raised when a file cannot be found.

    Returns:
        Tuple[List[str], List[str]]: The data directories that were searched, and the absolute paths of the
            requested files, in the same order as ``find_files``.

    Raises:
        FileNotFoundError: If any of the requested files could not be located.
    """

    # Standard command-line arguments for all samples.
    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # With action="append" and a non-empty default, directories given via -d are
    # added after the default root instead of replacing it.
    parser.add_argument(
        "-d",
        "--datadir",
        help="Location of the TensorRT sample data directory, and any additional data directories.",
        action="append",
        default=[kDEFAULT_DATA_ROOT],
    )
    args, _ = parser.parse_known_args()

    def get_data_path(data_dir):
        # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
        data_path = os.path.join(data_dir, subfolder)
        if not os.path.exists(data_path):
            # Stay quiet about the default root: it is always in the list and may simply not be installed.
            if data_dir != kDEFAULT_DATA_ROOT:
                logger.info("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
            data_path = data_dir
        # Make sure data directory exists.
        if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
            logger.info("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
        return data_path

    data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
    return data_paths, locate_files(data_paths, find_files, err_msg)


def locate_files(data_paths, filenames, err_msg=""):
    """
    Locates the specified files in the specified data directories.
    If a file exists in multiple data directories, the first directory is used.

    Args:
        data_paths (List[str]): The data directories.
        filenames (List[str]): The names of the files to find.
        err_msg (str): Extra text appended to the error message when a file is missing.

    Returns:
        List[str]: The absolute paths of the files, in the same order as ``filenames``.

    Raises:
        FileNotFoundError if a file could not be located.
    """
    found_files = [None] * len(filenames)
    for data_path in data_paths:
        # Find all requested files.
        for index, (found, filename) in enumerate(zip(found_files, filenames)):
            # Skip files already resolved from an earlier directory so the first match wins.
            if not found:
                file_path = os.path.abspath(os.path.join(data_path, filename))
                if os.path.exists(file_path):
                    found_files[index] = file_path

    # Check that all files were found
    for f, filename in zip(found_files, filenames):
        if not f or not os.path.exists(f):
            raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}\n{:}".format(filename, data_paths, err_msg))
    return found_files


# Sets up the builder config to use a timing cache loaded from the timing cache file
def setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
    """
    Attaches a timing cache to the builder config.

    If ``timing_cache_path`` exists, its bytes seed the cache; otherwise the cache is
    created from an empty buffer. This function does not write the file.

    Args:
        config (trt.IBuilderConfig): The builder config that receives the timing cache.
        timing_cache_path (os.PathLike): Path of the serialized timing cache file.
    """
    buffer = b""
    if os.path.exists(timing_cache_path):
        with open(timing_cache_path, mode="rb") as timing_cache_file:
            buffer = timing_cache_file.read()
    timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
    # NOTE(review): the positional True is presumably TensorRT's ignore_mismatch flag - confirm against the API docs.
    config.set_timing_cache(timing_cache, True)


# Saves the config's timing cache to file
def save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
    """
    Serializes the config's timing cache and writes it to ``timing_cache_path``.

    Args:
        config (trt.IBuilderConfig): The builder config whose timing cache is saved.
        timing_cache_path (os.PathLike): Destination file; it is opened in "wb" mode, so existing content is replaced.
    """
    timing_cache: trt.ITimingCache = config.get_timing_cache()
    with open(timing_cache_path, "wb") as timing_cache_file:
        timing_cache_file.write(memoryview(timing_cache.serialize()))