"vscode:/vscode.git/clone" did not exist on "794da715d409aa8d5b092a1692d7802624066ab3"
Unverified Commit 2e5b2342 authored by PanZezhong1725, committed by GitHub
Browse files

issue/547 - improved test output (#550)

parents bf3395f5 991f534c
...@@ -63,14 +63,22 @@ class TestCase: ...@@ -63,14 +63,22 @@ class TestCase:
if inp.init_mode != TensorInitializer.RANDOM if inp.init_mode != TensorInitializer.RANDOM
else "" else ""
) )
if hasattr(inp, "is_contiguous") and not inp.is_contiguous: # Show shape and strides for non-contiguous tensors
input_strs.append(f"strided_tensor{inp.shape}{dtype_str}{init_str}") if (
hasattr(inp, "is_contiguous")
and not inp.is_contiguous
and inp.strides
):
strides_str = f", strides={inp.strides}"
input_strs.append(
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
)
else: else:
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}") input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
else: else:
input_strs.append(str(inp)) input_strs.append(str(inp))
base_str = f"TestCase(mode={mode_str}, inputs=[{', '.join(input_strs)}]" base_str = f"TestCase(mode={mode_str}, inputs=[{'; '.join(input_strs)}]"
if self.output: if self.output:
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else "" dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
init_str = ( init_str = (
...@@ -78,7 +86,16 @@ class TestCase: ...@@ -78,7 +86,16 @@ class TestCase:
if self.output.init_mode != TensorInitializer.RANDOM if self.output.init_mode != TensorInitializer.RANDOM
else "" else ""
) )
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}" # Show shape and strides for non-contiguous output tensors
if (
hasattr(self.output, "is_contiguous")
and not self.output.is_contiguous
and self.output.strides
):
strides_str = f", strides={self.output.strides}"
base_str += f", output=tensor{self.output.shape}{strides_str}{dtype_str}{init_str}"
else:
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
if self.kwargs: if self.kwargs:
base_str += f", kwargs={self.kwargs}" base_str += f", kwargs={self.kwargs}"
if self.description: if self.description:
...@@ -131,24 +148,30 @@ class TestRunner: ...@@ -131,24 +148,30 @@ class TestRunner:
if self.config.dtype_combinations: if self.config.dtype_combinations:
for dtype_combo in self.config.dtype_combinations: for dtype_combo in self.config.dtype_combinations:
try: try:
test_func(device, test_case, dtype_combo, self.config) # Print test case info first
combo_str = self._format_dtype_combo(dtype_combo) combo_str = self._format_dtype_combo(dtype_combo)
print(f"✓ {test_case} with {combo_str} passed") print(f"{test_case} with {combo_str}")
test_func(device, test_case, dtype_combo, self.config)
print(f"\033[92m✓\033[0m Passed")
except Exception as e: except Exception as e:
combo_str = self._format_dtype_combo(dtype_combo) combo_str = self._format_dtype_combo(dtype_combo)
error_msg = f"{test_case} with {combo_str} on {InfiniDeviceNames[device]}: {e}" error_msg = f"Error: {e}"
print(f" {error_msg}") print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg) self.failed_tests.append(error_msg)
if self.config.debug: if self.config.debug:
raise raise
else: else:
for dtype in tensor_dtypes: for dtype in tensor_dtypes:
try: try:
# Print test case info first
print(f"{test_case} with {dtype}")
test_func(device, test_case, dtype, self.config) test_func(device, test_case, dtype, self.config)
print(f"{test_case} with {dtype} passed") print(f"\033[92m✓\033[0m Passed")
except Exception as e: except Exception as e:
error_msg = f"{test_case} with {dtype} on {InfiniDeviceNames[device]}: {e}" error_msg = f"Error: {e}"
print(f" {error_msg}") print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg) self.failed_tests.append(error_msg)
if self.config.debug: if self.config.debug:
raise raise
...@@ -214,7 +237,7 @@ class BaseOperatorTest(ABC): ...@@ -214,7 +237,7 @@ class BaseOperatorTest(ABC):
raise NotImplementedError("torch_operator not implemented") raise NotImplementedError("torch_operator not implemented")
def infinicore_operator(self, *inputs, out=None, **kwargs): def infinicore_operator(self, *inputs, out=None, **kwargs):
"""Unified Infinicore operator function - can be overridden or return None""" """Unified InfiniCore operator function - can be overridden or return None"""
raise NotImplementedError("infinicore_operator not implemented") raise NotImplementedError("infinicore_operator not implemented")
def create_strided_tensor( def create_strided_tensor(
...@@ -321,9 +344,7 @@ class BaseOperatorTest(ABC): ...@@ -321,9 +344,7 @@ class BaseOperatorTest(ABC):
# If neither operator is implemented, skip the test # If neither operator is implemented, skip the test
if not torch_implemented and not infini_implemented: if not torch_implemented and not infini_implemented:
print( print(f"⚠ Both operators not implemented - test skipped")
f"⚠ {self.operator_name} {mode_name}: Both operators not implemented - test skipped"
)
return return
# If only one operator is implemented, run it without comparison # If only one operator is implemented, run it without comparison
...@@ -332,7 +353,7 @@ class BaseOperatorTest(ABC): ...@@ -332,7 +353,7 @@ class BaseOperatorTest(ABC):
"torch_operator" if not torch_implemented else "infinicore_operator" "torch_operator" if not torch_implemented else "infinicore_operator"
) )
print( print(
f"⚠ {self.operator_name} {mode_name}: {missing_op} not implemented - running single operator without comparison" f"⚠ {missing_op} not implemented - running single operator without comparison"
) )
# Run the available operator for benchmarking if requested # Run the available operator for benchmarking if requested
...@@ -342,8 +363,9 @@ class BaseOperatorTest(ABC): ...@@ -342,8 +363,9 @@ class BaseOperatorTest(ABC):
def torch_op(): def torch_op():
return self.torch_operator(*inputs, **kwargs) return self.torch_operator(*inputs, **kwargs)
print(f" {mode_name}:")
profile_operation( profile_operation(
f"PyTorch {self.operator_name} {mode_name}", "PyTorch ",
torch_op, torch_op,
device_str, device_str,
config.num_prerun, config.num_prerun,
...@@ -354,8 +376,9 @@ class BaseOperatorTest(ABC): ...@@ -354,8 +376,9 @@ class BaseOperatorTest(ABC):
def infini_op(): def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs) return self.infinicore_operator(*infini_inputs, **kwargs)
print(f" {mode_name}:")
profile_operation( profile_operation(
f"Infinicore {self.operator_name} {mode_name}", "InfiniCore",
infini_op, infini_op,
device_str, device_str,
config.num_prerun, config.num_prerun,
...@@ -388,21 +411,22 @@ class BaseOperatorTest(ABC): ...@@ -388,21 +411,22 @@ class BaseOperatorTest(ABC):
) )
compare_fn = create_test_comparator( compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}" config, comparison_dtype, mode_name=f"{mode_name}"
) )
is_valid = compare_fn(infini_result, torch_result) is_valid = compare_fn(infini_result, torch_result)
assert is_valid, f"{self.operator_name} {mode_name} test failed" assert is_valid, f"{mode_name} result comparison failed"
if config.bench: if config.bench:
print(f" {mode_name}:")
profile_operation( profile_operation(
f"PyTorch {self.operator_name} {mode_name}", "PyTorch ",
torch_op, torch_op,
device_str, device_str,
config.num_prerun, config.num_prerun,
config.num_iterations, config.num_iterations,
) )
profile_operation( profile_operation(
f"Infinicore {self.operator_name} {mode_name}", "InfiniCore",
infini_op, infini_op,
device_str, device_str,
config.num_prerun, config.num_prerun,
...@@ -464,21 +488,22 @@ class BaseOperatorTest(ABC): ...@@ -464,21 +488,22 @@ class BaseOperatorTest(ABC):
test_case, dtype_config, torch_output test_case, dtype_config, torch_output
) )
compare_fn = create_test_comparator( compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}" config, comparison_dtype, mode_name=f"{mode_name}"
) )
is_valid = compare_fn(infini_output, torch_output) is_valid = compare_fn(infini_output, torch_output)
assert is_valid, f"{self.operator_name} {mode_name} test failed" assert is_valid, f"{mode_name} result comparison failed"
if config.bench: if config.bench:
print(f" {mode_name}:")
profile_operation( profile_operation(
f"PyTorch {self.operator_name} {mode_name}", "PyTorch ",
torch_op_inplace, torch_op_inplace,
device_str, device_str,
config.num_prerun, config.num_prerun,
config.num_iterations, config.num_iterations,
) )
profile_operation( profile_operation(
f"Infinicore {self.operator_name} {mode_name}", "InfiniCore",
infini_op_inplace, infini_op_inplace,
device_str, device_str,
config.num_prerun, config.num_prerun,
......
...@@ -34,7 +34,7 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations): ...@@ -34,7 +34,7 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
# Timed execution # Timed execution
elapsed = timed_op(lambda: func(), num_iterations, torch_device) elapsed = timed_op(lambda: func(), num_iterations, torch_device)
print(f" {desc} time: {elapsed * 1000 :6f} ms") print(f" {desc} time: {elapsed * 1000 :6f} ms")
def is_integer_dtype(dtype): def is_integer_dtype(dtype):
...@@ -157,7 +157,7 @@ def print_discrepancy( ...@@ -157,7 +157,7 @@ def print_discrepancy(
print( print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}" f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
) )
print("-" * total_width + "\n") print("-" * total_width)
return diff_indices return diff_indices
...@@ -273,7 +273,7 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""): ...@@ -273,7 +273,7 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
def compare_test_results(infini_result, torch_result): def compare_test_results(infini_result, torch_result):
if config.debug and mode_name: if config.debug and mode_name:
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m") print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
return compare_results( return compare_results(
infini_result, infini_result,
torch_result, torch_result,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment