Unverified Commit 41a484fa authored by Hongtao Zhang's avatar Hongtao Zhang Committed by GitHub
Browse files

Bugfix: Avoid Unintended nvbandwidth Function Calls in All Benchmarks (#685)



Root Cause:

1. '_get_all_test_cases()' was called in '_parser' while '_parser' was
defined in the base class.
2.  in '_get_all_test_cases()', cmd path was not included.

Fix:

1. Remove '_get_all_test_cases()' from '_parser'.
2. Construct path for cmd.

---------
Co-authored-by: default avatarhongtaozhang <hongtaozhang@microsoft.com>
parent 45d06647
...@@ -52,8 +52,8 @@ def add_parser_arguments(self): ...@@ -52,8 +52,8 @@ def add_parser_arguments(self):
required=False, required=False,
help=( help=(
'Specify the test case(s) to execute by name only. ' 'Specify the test case(s) to execute by name only. '
'To view the available test case names, run the command "nvbandwidth -l" on the host. '
'If no specific test case is specified, all test cases will be executed by default.' 'If no specific test case is specified, all test cases will be executed by default.'
'Supported test cases are: ' + ', '.join(self._get_all_test_cases())
), ),
) )
...@@ -263,14 +263,15 @@ def _process_raw_result(self, cmd_idx, raw_output): ...@@ -263,14 +263,15 @@ def _process_raw_result(self, cmd_idx, raw_output):
self._result.add_result('abort', 1) self._result.add_result('abort', 1)
return False return False
@staticmethod def _get_all_test_cases(self):
def _get_all_test_cases(): command = os.path.join(self._args.bin_dir, self._bin_name) + ' --list'
command = 'nvbandwidth -l'
test_case_pattern = re.compile(r'(\d+),\s+([\w_]+):') test_case_pattern = re.compile(r'(\d+),\s+([\w_]+):')
try: try:
# Execute the command and capture output # Execute the command and capture output
result = subprocess.run(command, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) result = subprocess.run(
command, shell=True, text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False
)
# Check the return code # Check the return code
if result.returncode != 0: if result.returncode != 0:
......
...@@ -138,6 +138,9 @@ def test_get_all_test_cases(self): ...@@ -138,6 +138,9 @@ def test_get_all_test_cases(self):
benchmark = benchmark_class(benchmark_name, parameters='') benchmark = benchmark_class(benchmark_name, parameters='')
# Call preprocess to initialize _args
assert benchmark._preprocess()
# Mock subprocess.run for successful execution with valid output # Mock subprocess.run for successful execution with valid output
with unittest.mock.patch('subprocess.run') as mock_run: with unittest.mock.patch('subprocess.run') as mock_run:
mock_run.return_value.returncode = 0 mock_run.return_value.returncode = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment