Unverified Commit 2e5b2342 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

issue/547 - improved test output (#550)

parents bf3395f5 991f534c
......@@ -63,14 +63,22 @@ class TestCase:
if inp.init_mode != TensorInitializer.RANDOM
else ""
)
if hasattr(inp, "is_contiguous") and not inp.is_contiguous:
input_strs.append(f"strided_tensor{inp.shape}{dtype_str}{init_str}")
# Show shape and strides for non-contiguous tensors
if (
hasattr(inp, "is_contiguous")
and not inp.is_contiguous
and inp.strides
):
strides_str = f", strides={inp.strides}"
input_strs.append(
f"tensor{inp.shape}{strides_str}{dtype_str}{init_str}"
)
else:
input_strs.append(f"tensor{inp.shape}{dtype_str}{init_str}")
else:
input_strs.append(str(inp))
base_str = f"TestCase(mode={mode_str}, inputs=[{', '.join(input_strs)}]"
base_str = f"TestCase(mode={mode_str}, inputs=[{'; '.join(input_strs)}]"
if self.output:
dtype_str = f", dtype={self.output.dtype}" if self.output.dtype else ""
init_str = (
......@@ -78,7 +86,16 @@ class TestCase:
if self.output.init_mode != TensorInitializer.RANDOM
else ""
)
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
# Show shape and strides for non-contiguous output tensors
if (
hasattr(self.output, "is_contiguous")
and not self.output.is_contiguous
and self.output.strides
):
strides_str = f", strides={self.output.strides}"
base_str += f", output=tensor{self.output.shape}{strides_str}{dtype_str}{init_str}"
else:
base_str += f", output=tensor{self.output.shape}{dtype_str}{init_str}"
if self.kwargs:
base_str += f", kwargs={self.kwargs}"
if self.description:
......@@ -131,24 +148,30 @@ class TestRunner:
if self.config.dtype_combinations:
for dtype_combo in self.config.dtype_combinations:
try:
test_func(device, test_case, dtype_combo, self.config)
# Print test case info first
combo_str = self._format_dtype_combo(dtype_combo)
print(f"✓ {test_case} with {combo_str} passed")
print(f"{test_case} with {combo_str}")
test_func(device, test_case, dtype_combo, self.config)
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
combo_str = self._format_dtype_combo(dtype_combo)
error_msg = f"{test_case} with {combo_str} on {InfiniDeviceNames[device]}: {e}"
print(f" {error_msg}")
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
else:
for dtype in tensor_dtypes:
try:
# Print test case info first
print(f"{test_case} with {dtype}")
test_func(device, test_case, dtype, self.config)
print(f"{test_case} with {dtype} passed")
print(f"\033[92m✓\033[0m Passed")
except Exception as e:
error_msg = f"{test_case} with {dtype} on {InfiniDeviceNames[device]}: {e}"
print(f" {error_msg}")
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
if self.config.debug:
raise
......@@ -214,7 +237,7 @@ class BaseOperatorTest(ABC):
raise NotImplementedError("torch_operator not implemented")
def infinicore_operator(self, *inputs, out=None, **kwargs):
"""Unified Infinicore operator function - can be overridden or return None"""
"""Unified InfiniCore operator function - can be overridden or return None"""
raise NotImplementedError("infinicore_operator not implemented")
def create_strided_tensor(
......@@ -321,9 +344,7 @@ class BaseOperatorTest(ABC):
# If neither operator is implemented, skip the test
if not torch_implemented and not infini_implemented:
print(
f"⚠ {self.operator_name} {mode_name}: Both operators not implemented - test skipped"
)
print(f"⚠ Both operators not implemented - test skipped")
return
# If only one operator is implemented, run it without comparison
......@@ -332,7 +353,7 @@ class BaseOperatorTest(ABC):
"torch_operator" if not torch_implemented else "infinicore_operator"
)
print(
f"⚠ {self.operator_name} {mode_name}: {missing_op} not implemented - running single operator without comparison"
f"⚠ {missing_op} not implemented - running single operator without comparison"
)
# Run the available operator for benchmarking if requested
......@@ -342,8 +363,9 @@ class BaseOperatorTest(ABC):
def torch_op():
return self.torch_operator(*inputs, **kwargs)
print(f" {mode_name}:")
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
......@@ -354,8 +376,9 @@ class BaseOperatorTest(ABC):
def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs)
print(f" {mode_name}:")
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
......@@ -388,21 +411,22 @@ class BaseOperatorTest(ABC):
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
config, comparison_dtype, mode_name=f"{mode_name}"
)
is_valid = compare_fn(infini_result, torch_result)
assert is_valid, f"{self.operator_name} {mode_name} test failed"
assert is_valid, f"{mode_name} result comparison failed"
if config.bench:
print(f" {mode_name}:")
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
"PyTorch ",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
"InfiniCore",
infini_op,
device_str,
config.num_prerun,
......@@ -464,21 +488,22 @@ class BaseOperatorTest(ABC):
test_case, dtype_config, torch_output
)
compare_fn = create_test_comparator(
config, comparison_dtype, mode_name=f"{self.operator_name} {mode_name}"
config, comparison_dtype, mode_name=f"{mode_name}"
)
is_valid = compare_fn(infini_output, torch_output)
assert is_valid, f"{self.operator_name} {mode_name} test failed"
assert is_valid, f"{mode_name} result comparison failed"
if config.bench:
print(f" {mode_name}:")
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
"PyTorch ",
torch_op_inplace,
device_str,
config.num_prerun,
config.num_iterations,
)
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
"InfiniCore",
infini_op_inplace,
device_str,
config.num_prerun,
......
......@@ -34,7 +34,7 @@ def profile_operation(desc, func, torch_device, num_prerun, num_iterations):
# Timed execution
elapsed = timed_op(lambda: func(), num_iterations, torch_device)
print(f" {desc} time: {elapsed * 1000 :6f} ms")
print(f" {desc} time: {elapsed * 1000 :6f} ms")
def is_integer_dtype(dtype):
......@@ -157,7 +157,7 @@ def print_discrepancy(
print(
f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
)
print("-" * total_width + "\n")
print("-" * total_width)
return diff_indices
......@@ -273,7 +273,7 @@ def create_test_comparator(config, dtype, tolerance_map=None, mode_name=""):
def compare_test_results(infini_result, torch_result):
if config.debug and mode_name:
print(f"\n\033[94mDEBUG INFO - {mode_name}:\033[0m")
print(f"\033[94mDEBUG INFO - {mode_name}:\033[0m")
return compare_results(
infini_result,
torch_result,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment