Unverified Commit d03d110f authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Add Feature - Sync the E2E training results among all workers for each step. (#287)

**Description**
Please write a brief description and link the related issue if there is one.

**Major Revision**
- Sync (do allreduce max) the E2E training results among all workers.
- Avoid using ':0' in the metric name when only one rank produced output.
parent d877ca23
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
from superbench.benchmarks import Precision, ModelAction, DistributedImpl, DistributedBackend, BenchmarkType, ReturnCode from superbench.benchmarks import Precision, ModelAction, DistributedImpl, DistributedBackend, BenchmarkType, ReturnCode
from superbench.benchmarks.base import Benchmark from superbench.benchmarks.base import Benchmark
from superbench.benchmarks.context import Enum from superbench.benchmarks.context import Enum
from superbench.benchmarks.reducer import ReduceType
class Optimizer(Enum): class Optimizer(Enum):
...@@ -344,6 +343,15 @@ def _benchmark(self): ...@@ -344,6 +343,15 @@ def _benchmark(self):
return True return True
def _is_finished(self, curr_step, curr_time): def _is_finished(self, curr_step, curr_time):
"""Judge whether the benchmarking should be stopped early or not.
Args:
curr_step (int): the current benchmarking step.
curr_time (float): the current time in seconds got from time.time().
Return:
True if the benchmarking should be stopped.
"""
total_steps = self._args.num_warmup + self._args.num_steps total_steps = self._args.num_warmup + self._args.num_steps
if ( if (
...@@ -354,6 +362,17 @@ def _is_finished(self, curr_step, curr_time): ...@@ -354,6 +362,17 @@ def _is_finished(self, curr_step, curr_time):
return False return False
def _sync_result(self, result):
"""Function to reduce the result to rank 0.
Args:
result (list): The result data to sync.
Return:
True if reduce result data successfully.
"""
return True
def __process_model_result(self, model_action, precision, step_times): def __process_model_result(self, model_action, precision, step_times):
"""Function to process raw results and save the summarized results. """Function to process raw results and save the summarized results.
...@@ -376,22 +395,26 @@ def __process_model_result(self, model_action, precision, step_times): ...@@ -376,22 +395,26 @@ def __process_model_result(self, model_action, precision, step_times):
precision_metric = {'float16': 'fp16', 'float32': 'fp32', 'float64': 'fp64', 'bfloat16': 'bf16'} precision_metric = {'float16': 'fp16', 'float32': 'fp32', 'float64': 'fp64', 'bfloat16': 'bf16'}
if precision.value in precision_metric.keys(): if precision.value in precision_metric.keys():
precision = precision_metric[precision.value] precision = precision_metric[precision.value]
metric = '{}_{}_step_time'.format(precision, model_action) metric_s = '{}_{}_step_time'.format(precision, model_action)
reduce_type = ReduceType.MAX if model_action is ModelAction.TRAIN else None metric_t = '{}_{}_throughput'.format(precision, model_action)
self._result.add_raw_data(metric, step_times)
self._result.add_result(metric, statistics.mean(step_times), reduce_type=reduce_type)
if model_action == ModelAction.INFERENCE:
self._process_percentile_result(metric, step_times, reduce_type=reduce_type)
# The unit of step time is millisecond, use it to calculate the throughput with the unit samples/sec. # The unit of step time is millisecond, use it to calculate the throughput with the unit samples/sec.
millisecond_per_second = 1000 millisecond_per_second = 1000
throughput = [millisecond_per_second / step_time * self._args.batch_size for step_time in step_times] throughput = [millisecond_per_second / step_time * self._args.batch_size for step_time in step_times]
metric = '{}_{}_throughput'.format(precision, model_action) self._result.add_raw_data(metric_s, step_times)
reduce_type = ReduceType.MIN if model_action is ModelAction.TRAIN else None self._result.add_raw_data(metric_t, throughput)
self._result.add_raw_data(metric, throughput)
self._result.add_result(metric, statistics.mean(throughput), reduce_type=reduce_type) if model_action == ModelAction.TRAIN:
if model_action == ModelAction.INFERENCE: if not self._sync_result(step_times):
self._process_percentile_result(metric, throughput, reduce_type=reduce_type) return False
if self._local_rank is None or self._local_rank == 0:
self._result.add_result(metric_s, statistics.mean(step_times))
throughput = [millisecond_per_second / step_time * self._args.batch_size for step_time in step_times]
self._result.add_result(metric_t, statistics.mean(throughput))
elif model_action == ModelAction.INFERENCE:
self._result.add_result(metric_s, statistics.mean(step_times))
self._result.add_result(metric_t, statistics.mean(throughput))
self._process_percentile_result(metric_s, step_times)
self._process_percentile_result(metric_t, throughput)
return True return True
......
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from superbench.common.utils import logger from superbench.common.utils import logger
from superbench.benchmarks import Framework, ReturnCode from superbench.benchmarks import Framework, ReturnCode, DistributedBackend, DistributedImpl
from superbench.benchmarks.model_benchmarks.model_base import Optimizer, DistributedImpl, ModelBenchmark from superbench.benchmarks.model_benchmarks.model_base import Optimizer, ModelBenchmark
class PytorchBase(ModelBenchmark): class PytorchBase(ModelBenchmark):
...@@ -172,6 +172,36 @@ def _create_optimizer(self): ...@@ -172,6 +172,36 @@ def _create_optimizer(self):
return True return True
def _sync_result(self, result):
    """Function to reduce the result to rank 0.

    Performs an element-wise MAX reduce of the per-step results across all
    workers so that rank 0 reports the slowest (end-to-end) step times.
    The reduced values are written back into ``result`` in place.

    Args:
        result (list): The result data to sync; updated in place on success.

    Return:
        True if reduce result data successfully.
    """
    if not super()._sync_result(result):
        return False

    try:
        if self._args.distributed_impl == DistributedImpl.DDP:
            # NCCL can only operate on GPU tensors; other backends reduce on host.
            if self._args.distributed_backend == DistributedBackend.NCCL:
                tensor = torch.as_tensor(result).cuda()
            else:
                tensor = torch.as_tensor(result)
            torch.distributed.reduce(tensor, 0, op=torch.distributed.ReduceOp.MAX)
            # Mutate the caller's list in place -- rebinding the local name
            # (`result = tensor.tolist()`) would silently discard the synced
            # data, leaving rank 0 with its own unreduced step times.
            result[:] = tensor.tolist()
    except BaseException as e:
        logger.error(
            'Sync train result failed - model: {}, distributed implementation: {}, message: {}.'.format(
                self._name, self._args.distributed_impl, str(e)
            )
        )
        return False

    return True
def _postprocess(self): def _postprocess(self):
"""Postprocess/cleanup operations after the benchmarking. """Postprocess/cleanup operations after the benchmarking.
......
...@@ -267,6 +267,33 @@ def __create_single_node_summary(self, node_path): # pragma: no cover # noqa: ...@@ -267,6 +267,33 @@ def __create_single_node_summary(self, node_path): # pragma: no cover # noqa:
return results_summary return results_summary
def __generate_metric_name(self, benchmark_name, metric, rank_count, run_count, curr_rank, curr_run):
"""Generate the summarized metrics name.
The format of metric name is:
{benchmark_name}/[{run_count}/]{metric_name}[:rank]
[run_count] and [rank] parts are optional.
Args:
benchmark_name (str): The benchmark name.
metric (str): The metric name.
rank_count (int): The total count of rank.
run_count (int): The total count of benchmarking.
curr_rank (int): The current rank index.
curr_run (int): The current run index.
Returns:
dict: Flattened result with metric as key.
"""
metric_name = benchmark_name
if run_count > 1:
metric_name = '{}/{}'.format(metric_name, curr_run)
metric_name = '{}/{}'.format(metric_name, metric)
if rank_count > 1:
metric_name = '{}:{}'.format(metric_name, curr_rank)
return metric_name
def __merge_benchmark_metrics(self, results_summary, reduce_ops): def __merge_benchmark_metrics(self, results_summary, reduce_ops):
"""Merge metrics of all benchmarks in one node. """Merge metrics of all benchmarks in one node.
...@@ -290,20 +317,18 @@ def __merge_benchmark_metrics(self, results_summary, reduce_ops): ...@@ -290,20 +317,18 @@ def __merge_benchmark_metrics(self, results_summary, reduce_ops):
if reduce_ops[metric_name] is not None: if reduce_ops[metric_name] is not None:
reduce_func = Reducer.get_reduce_func(ReduceType(reduce_ops[metric_name])) reduce_func = Reducer.get_reduce_func(ReduceType(reduce_ops[metric_name]))
values = [reduce_func(list(result)) for result in zip(*results_summary[benchmark_name][metric])] values = [reduce_func(list(result)) for result in zip(*results_summary[benchmark_name][metric])]
for run_count in range(len(values)): for run in range(len(values)):
if len(values) > 1: metric_name = self.__generate_metric_name(benchmark_name, metric, 1, len(values), 0, run)
metric_name = '{}/{}/{}'.format(benchmark_name, run_count, metric) metrics_summary[metric_name] = values[run]
else:
metric_name = '{}/{}'.format(benchmark_name, metric)
metrics_summary[metric_name] = values[run_count]
else: else:
for rank in range(len(results_summary[benchmark_name][metric])): rank_count = len(results_summary[benchmark_name][metric])
for run_count in range(len(results_summary[benchmark_name][metric][rank])): for rank, rank_value in enumerate(results_summary[benchmark_name][metric]):
if len(results_summary[benchmark_name][metric][rank]) > 1: run_count = len(rank_value)
metric_name = '{}/{}/{}:{}'.format(benchmark_name, run_count, metric, rank) for run, run_value in enumerate(rank_value):
else: metric_name = self.__generate_metric_name(
metric_name = '{}/{}:{}'.format(benchmark_name, metric, rank) benchmark_name, metric, rank_count, run_count, rank, run
metrics_summary[metric_name] = results_summary[benchmark_name][metric][rank][run_count] )
metrics_summary[metric_name] = run_value
return metrics_summary return metrics_summary
......
...@@ -3,6 +3,8 @@ ...@@ -3,6 +3,8 @@
"""Tests for BenchmarkRegistry module.""" """Tests for BenchmarkRegistry module."""
import json
from superbench.benchmarks import Platform, Framework, Precision, BenchmarkRegistry, BenchmarkType, ReturnCode from superbench.benchmarks import Platform, Framework, Precision, BenchmarkRegistry, BenchmarkType, ReturnCode
from superbench.benchmarks.model_benchmarks import ModelBenchmark from superbench.benchmarks.model_benchmarks import ModelBenchmark
...@@ -226,11 +228,11 @@ def test_train(): ...@@ -226,11 +228,11 @@ def test_train():
'"fp32_train_step_time": [[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]], ' '"fp32_train_step_time": [[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]], '
'"fp32_train_throughput": [[16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0]]}, ' '"fp32_train_throughput": [[16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0]]}, '
'"result": {"return_code": [0], "fp32_train_step_time": [2.0], "fp32_train_throughput": [16000.0]}, ' '"result": {"return_code": [0], "fp32_train_step_time": [2.0], "fp32_train_throughput": [16000.0]}, '
'"reduce_op": {"return_code": null, "fp32_train_step_time": "max", "fp32_train_throughput": "min"}}' '"reduce_op": {"return_code": null, "fp32_train_step_time": null, "fp32_train_throughput": null}}'
) )
assert (benchmark._preprocess()) assert (benchmark._preprocess())
assert (benchmark._ModelBenchmark__train(Precision.FLOAT32)) assert (benchmark._ModelBenchmark__train(Precision.FLOAT32))
assert (benchmark.serialized_result == expected_result) assert (json.loads(benchmark.serialized_result) == json.loads(expected_result))
# Step time list is empty (simulate training failure). # Step time list is empty (simulate training failure).
benchmark = create_benchmark('--num_steps 0') benchmark = create_benchmark('--num_steps 0')
...@@ -241,7 +243,7 @@ def test_train(): ...@@ -241,7 +243,7 @@ def test_train():
) )
assert (benchmark._preprocess()) assert (benchmark._preprocess())
assert (benchmark._ModelBenchmark__train(Precision.FLOAT32) is False) assert (benchmark._ModelBenchmark__train(Precision.FLOAT32) is False)
assert (benchmark.serialized_result == expected_result) assert (json.loads(benchmark.serialized_result) == json.loads(expected_result))
def test_inference(): def test_inference():
...@@ -270,7 +272,7 @@ def test_inference(): ...@@ -270,7 +272,7 @@ def test_inference():
) )
assert (benchmark._preprocess()) assert (benchmark._preprocess())
assert (benchmark._ModelBenchmark__inference(Precision.FLOAT16)) assert (benchmark._ModelBenchmark__inference(Precision.FLOAT16))
assert (benchmark.serialized_result == expected_result) assert (json.loads(benchmark.serialized_result) == json.loads(expected_result))
# Step time list is empty (simulate inference failure). # Step time list is empty (simulate inference failure).
benchmark = create_benchmark('--num_steps 0') benchmark = create_benchmark('--num_steps 0')
...@@ -281,7 +283,7 @@ def test_inference(): ...@@ -281,7 +283,7 @@ def test_inference():
) )
assert (benchmark._preprocess()) assert (benchmark._preprocess())
assert (benchmark._ModelBenchmark__inference(Precision.FLOAT16) is False) assert (benchmark._ModelBenchmark__inference(Precision.FLOAT16) is False)
assert (benchmark.serialized_result == expected_result) assert (json.loads(benchmark.serialized_result) == json.loads(expected_result))
def test_benchmark(): def test_benchmark():
...@@ -318,10 +320,10 @@ def test_benchmark(): ...@@ -318,10 +320,10 @@ def test_benchmark():
'"fp16_train_throughput": [[16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0]]}, ' '"fp16_train_throughput": [[16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0, 16000.0]]}, '
'"result": {"return_code": [0], "fp32_train_step_time": [2.0], "fp32_train_throughput": [16000.0], ' '"result": {"return_code": [0], "fp32_train_step_time": [2.0], "fp32_train_throughput": [16000.0], '
'"fp16_train_step_time": [2.0], "fp16_train_throughput": [16000.0]}, ' '"fp16_train_step_time": [2.0], "fp16_train_throughput": [16000.0]}, '
'"reduce_op": {"return_code": null, "fp32_train_step_time": "max", "fp32_train_throughput": "min", ' '"reduce_op": {"return_code": null, "fp32_train_step_time": null, "fp32_train_throughput": null, '
'"fp16_train_step_time": "max", "fp16_train_throughput": "min"}}' '"fp16_train_step_time": null, "fp16_train_throughput": null}}'
) )
assert (benchmark.serialized_result == expected_serialized_result) assert (json.loads(benchmark.serialized_result) == json.loads(expected_serialized_result))
# Negative case for _benchmark() - no supported precision found. # Negative case for _benchmark() - no supported precision found.
benchmark = create_benchmark('--precision int16') benchmark = create_benchmark('--precision int16')
......
...@@ -202,30 +202,24 @@ def test_merge_benchmark_metrics(self): ...@@ -202,30 +202,24 @@ def test_merge_benchmark_metrics(self):
'{"kernel-launch": {"overhead_event": [[0.00583], [0.00545], [0.00581], [0.00572], [0.00559], [0.00591], ' '{"kernel-launch": {"overhead_event": [[0.00583], [0.00545], [0.00581], [0.00572], [0.00559], [0.00591], '
'[0.00562], [0.00586]], "overhead_wall": [[0.01018], [0.01039], [0.01067], [0.01079], [0.00978], ' '[0.00562], [0.00586]], "overhead_wall": [[0.01018], [0.01039], [0.01067], [0.01079], [0.00978], '
'[0.01085], [0.01036], [0.01033]]}, "resnet_models/pytorch-resnet50": {"steptime_train_float32": ' '[0.01085], [0.01036], [0.01033]]}, "resnet_models/pytorch-resnet50": {"steptime_train_float32": '
'[[252.03], [250.53], [253.75], [250.61], [252.86], [252.58], [251.15], [252.83]], ' '[[252.03]], "throughput_train_float32": [[764.57]], "steptime_train_float16": [[198.36]], '
'"throughput_train_float32": [[764.57], [767.83], [762.19], [767.31], [763.41], [764.31], [766.43], ' '"throughput_train_float16": [[972.64]]}, "resnet_models/pytorch-resnet101": {"steptime_train_float32": '
'[763.38]], "steptime_train_float16": [[198.36], [196.85], [200.55], [198.07], [199.41], [199.20], ' '[[385.53]], "throughput_train_float32": [[499.39]], "steptime_train_float16": [[307.49]], '
'[199.07], [199.34]], "throughput_train_float16": [[972.64], [977.31], [969.58], [974.33], [972.87], ' '"throughput_train_float16": [[627.21]]}, "pytorch-sharding-matmul": {"allreduce": [[10.56, 10.66], '
'[972.73], [972.46], [972.46]]}, "resnet_models/pytorch-resnet101": {"steptime_train_float32": [[385.53], ' '[10.87, 10.32], [10.56, 10.45], [10.56, 10.60], [10.56, 10.45], [10.56, 10.38], [10.56, 10.33], '
'[384.05], [386.98], [385.12], [385.47], [385.81], [384.90], [386.65]], "throughput_train_float32": ' '[10.56, 10.69]], "allgather": [[10.08, 10.10], [10.08, 10.16], [10.08, 10.06], [10.56, 10.04], '
'[[499.39], [500.69], [498.57], [499.83], [499.51], [499.27], [499.94], [498.65]], ' '[10.08, 10.05], [10.08, 10.09], [10.08, 10.08], [10.08, 10.06]]}}'
'"steptime_train_float16": [[307.49], [307.13], [310.31], [307.64], [308.68], [309.61], [307.71], '
'[309.95]], "throughput_train_float16": [[627.21], [627.34], [624.85], [626.76], [626.26], [625.12], '
'[626.92], [625.02]]}, "pytorch-sharding-matmul": {"allreduce": [[10.56, 10.66], [10.87, 10.32], '
'[10.56, 10.45], [10.56, 10.60], [10.56, 10.45], [10.56, 10.38], [10.56, 10.33], [10.56, 10.69]], '
'"allgather": [[10.08, 10.10], [10.08, 10.16], [10.08, 10.06], [10.56, 10.04], [10.08, 10.05], '
'[10.08, 10.09], [10.08, 10.08], [10.08, 10.06]]}}'
) )
reduce_ops = json.loads( reduce_ops = json.loads(
'{"kernel-launch/overhead_event": null, "kernel-launch/overhead_wall": null, ' '{"kernel-launch/overhead_event": null, "kernel-launch/overhead_wall": null, '
'"resnet_models/pytorch-resnet50/steptime_train_float32": "max", ' '"resnet_models/pytorch-resnet50/steptime_train_float32": null, '
'"resnet_models/pytorch-resnet50/throughput_train_float32": "min", ' '"resnet_models/pytorch-resnet50/throughput_train_float32": null, '
'"resnet_models/pytorch-resnet50/steptime_train_float16": "max", ' '"resnet_models/pytorch-resnet50/steptime_train_float16": null, '
'"resnet_models/pytorch-resnet50/throughput_train_float16": "min", ' '"resnet_models/pytorch-resnet50/throughput_train_float16": null, '
'"resnet_models/pytorch-resnet101/steptime_train_float32": "max", ' '"resnet_models/pytorch-resnet101/steptime_train_float32": null, '
'"resnet_models/pytorch-resnet101/throughput_train_float32": "min", ' '"resnet_models/pytorch-resnet101/throughput_train_float32": null, '
'"resnet_models/pytorch-resnet101/steptime_train_float16": "max", ' '"resnet_models/pytorch-resnet101/steptime_train_float16": null, '
'"resnet_models/pytorch-resnet101/throughput_train_float16": "min", ' '"resnet_models/pytorch-resnet101/throughput_train_float16": null, '
'"pytorch-sharding-matmul/allreduce": "max", "pytorch-sharding-matmul/allgather": "max"}' '"pytorch-sharding-matmul/allreduce": "max", "pytorch-sharding-matmul/allgather": "max"}'
) )
...@@ -238,14 +232,14 @@ def test_merge_benchmark_metrics(self): ...@@ -238,14 +232,14 @@ def test_merge_benchmark_metrics(self):
'"kernel-launch/overhead_wall:2": 0.01067, "kernel-launch/overhead_wall:3": 0.01079, ' '"kernel-launch/overhead_wall:2": 0.01067, "kernel-launch/overhead_wall:3": 0.01079, '
'"kernel-launch/overhead_wall:4": 0.00978, "kernel-launch/overhead_wall:5": 0.01085, ' '"kernel-launch/overhead_wall:4": 0.00978, "kernel-launch/overhead_wall:5": 0.01085, '
'"kernel-launch/overhead_wall:6": 0.01036, "kernel-launch/overhead_wall:7": 0.01033, ' '"kernel-launch/overhead_wall:6": 0.01036, "kernel-launch/overhead_wall:7": 0.01033, '
'"resnet_models/pytorch-resnet50/steptime_train_float32": 253.75, ' '"resnet_models/pytorch-resnet50/steptime_train_float32": 252.03, '
'"resnet_models/pytorch-resnet50/throughput_train_float32": 762.19, ' '"resnet_models/pytorch-resnet50/throughput_train_float32": 764.57, '
'"resnet_models/pytorch-resnet50/steptime_train_float16": 200.55, ' '"resnet_models/pytorch-resnet50/steptime_train_float16": 198.36, '
'"resnet_models/pytorch-resnet50/throughput_train_float16": 969.58, ' '"resnet_models/pytorch-resnet50/throughput_train_float16": 972.64, '
'"resnet_models/pytorch-resnet101/steptime_train_float32": 386.98, ' '"resnet_models/pytorch-resnet101/steptime_train_float32": 385.53, '
'"resnet_models/pytorch-resnet101/throughput_train_float32": 498.57, ' '"resnet_models/pytorch-resnet101/throughput_train_float32": 499.39, '
'"resnet_models/pytorch-resnet101/steptime_train_float16": 310.31, ' '"resnet_models/pytorch-resnet101/steptime_train_float16": 307.49, '
'"resnet_models/pytorch-resnet101/throughput_train_float16": 624.85, ' '"resnet_models/pytorch-resnet101/throughput_train_float16": 627.21, '
'"pytorch-sharding-matmul/0/allreduce": 10.87, "pytorch-sharding-matmul/1/allreduce": 10.69, ' '"pytorch-sharding-matmul/0/allreduce": 10.87, "pytorch-sharding-matmul/1/allreduce": 10.69, '
'"pytorch-sharding-matmul/0/allgather": 10.56, "pytorch-sharding-matmul/1/allgather": 10.16}' '"pytorch-sharding-matmul/0/allgather": 10.56, "pytorch-sharding-matmul/1/allgather": 10.16}'
) )
...@@ -289,3 +283,63 @@ def test_merge_monitor_metrics(self): ...@@ -289,3 +283,63 @@ def test_merge_monitor_metrics(self):
'monitor/gpu_uncorrected_ecc:7': 0 'monitor/gpu_uncorrected_ecc:7': 0
} }
self.assertEqual(self.runner._SuperBenchRunner__merge_monitor_metrics(path), expected) self.assertEqual(self.runner._SuperBenchRunner__merge_monitor_metrics(path), expected)
def test_generate_metric_name(self):
    """Test __generate_metric_name for every rank_count/run_count combination.

    Covers multi-rank multi-run, single-rank single-run (no suffixes), and
    single-rank multi-run metric name formats.
    """
    # NOTE: the stray bare-string statement that duplicated the method
    # signature after the docstring was dead code and has been removed.
    test_cases = [
        {
            'benchmark_name': 'kernel-launch',
            'metric': 'overhead_event',
            'rank_count': 8,
            'run_count': 2,
            'curr_rank': 0,
            'curr_run': 0,
            'expected': 'kernel-launch/0/overhead_event:0',
        },
        {
            'benchmark_name': 'kernel-launch',
            'metric': 'overhead_event',
            'rank_count': 8,
            'run_count': 2,
            'curr_rank': 2,
            'curr_run': 1,
            'expected': 'kernel-launch/1/overhead_event:2',
        },
        {
            'benchmark_name': 'kernel-launch',
            'metric': 'overhead_event',
            'rank_count': 1,
            'run_count': 1,
            'curr_rank': 0,
            'curr_run': 0,
            'expected': 'kernel-launch/overhead_event',
        },
        {
            'benchmark_name': 'resnet_models/pytorch-resnet50',
            'metric': 'fp32_train_step_time',
            'rank_count': 1,
            'run_count': 2,
            'curr_rank': 0,
            'curr_run': 1,
            'expected': 'resnet_models/pytorch-resnet50/1/fp32_train_step_time',
        },
        {
            'benchmark_name': 'resnet_models/pytorch-resnet50',
            'metric': 'fp32_train_step_time',
            'rank_count': 1,
            'run_count': 1,
            'curr_rank': 0,
            'curr_run': 0,
            'expected': 'resnet_models/pytorch-resnet50/fp32_train_step_time',
        },
    ]
    for test_case in test_cases:
        with self.subTest(msg='Testing with case', test_case=test_case):
            self.assertEqual(
                self.runner._SuperBenchRunner__generate_metric_name(
                    test_case['benchmark_name'], test_case['metric'], test_case['rank_count'],
                    test_case['run_count'], test_case['curr_rank'], test_case['curr_run']
                ), test_case['expected']
            )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment