Unverified Commit e16b6dbc authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Issue/672 a (#673)

* issue/672 - fixed non-contiguous operations in the test framework
parent 79e6883b
......@@ -13,6 +13,7 @@ from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer
from .utils import (
clone_torch_tensor,
create_test_comparator,
infinicore_tensor_from_torch,
)
......@@ -321,7 +322,7 @@ class BaseOperatorTest(ABC):
for item in input_sequence:
if isinstance(item, torch.Tensor):
if clone:
cloned_item = item.clone().detach()
cloned_item = clone_torch_tensor(item)
infini_item = infinicore_tensor_from_torch(cloned_item)
cloned_tensors.append(cloned_item)
else:
......@@ -340,7 +341,7 @@ class BaseOperatorTest(ABC):
if isinstance(inp, torch.Tensor):
# Clone only if this input will be used for comparison
if comparison_target == i:
cloned_inp = inp.clone().detach()
cloned_inp = clone_torch_tensor(inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
cloned_tensors.append(cloned_inp)
else:
......@@ -362,7 +363,7 @@ class BaseOperatorTest(ABC):
if isinstance(value, torch.Tensor):
# Check if this tensor is used for output comparison
if key == "out" and comparison_target == "out":
cloned_value = value.clone().detach()
cloned_value = clone_torch_tensor(value)
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
cloned_tensors.append(cloned_value)
elif key == "out" and isinstance(comparison_target, int):
......@@ -566,12 +567,12 @@ class BaseOperatorTest(ABC):
elif comparison_target == "out":
# Compare output tensor from kwargs (explicit output)
torch_comparison = kwargs.get("out")
infini_comparison = infini_kwargs.get("out")
infini_comparison = cloned_tensors[0]
elif isinstance(comparison_target, int):
# Compare specific input tensor (in-place operation on input)
if 0 <= comparison_target < len(inputs):
torch_comparison = inputs[comparison_target]
infini_comparison = infini_inputs[comparison_target]
infini_comparison = cloned_tensors[0]
else:
raise ValueError(
f"Invalid comparison target index: {comparison_target}"
......
......@@ -118,6 +118,13 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
return tolerance["atol"], tolerance["rtol"]
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous():
......@@ -152,6 +159,10 @@ def convert_infinicore_to_torch(infini_result):
dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type,
)
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result)
return torch_result_from_infini
......@@ -223,7 +234,10 @@ def compare_results(
return result_equal
# Convert infinicore result to PyTorch tensor for comparison
torch_result_from_infini = convert_infinicore_to_torch(infini_result)
if isinstance(infini_result, torch.Tensor):
torch_result_from_infini = infini_result
else:
torch_result_from_infini = convert_infinicore_to_torch(infini_result)
# Debug mode: detailed comparison
if debug_mode:
......
......@@ -49,8 +49,8 @@ _TEST_CASES_DATA = [
((13, 4), 0, False, None, (3,), (3,)),
((13, 4), 1, False, (20, 1), (10,), (10,)),
# 3D in-place cases
((4, 5, 6), 1, True, None, (4, 1, 6), (4, 1, 6)),
((4, 5, 6), -1, False, (30, 6, 1), (4, 5), (4, 5)),
((4, 5, 6), 1, True, None, (6, 6, 1), (6, 6, 1)),
((4, 5, 6), -1, False, (30, 6, 1), (5, 1), (5, 1)),
]
# Tolerance configuration
......
......@@ -28,7 +28,6 @@ _TEST_CASES_DATA = [
((4, 48, 6), None, None),
# Strided tensors
((1, 2048), (4096, 1), (4096, 1)),
((6, 2560), (2048, 1), (2560, 1)),
# Mixed cases
((8, 16, 32), None, None),
# Large tensors
......
......@@ -31,12 +31,12 @@ _TEST_CASES_DATA = [
((4, 5, 6), 1, False, None, None, None),
((4, 5, 6), -1, True, None, None, None),
# 3D in-place cases
((4, 5, 6), 1, False, None, (4, 1, 6), (4, 1, 6)),
((4, 5, 6), -1, False, (30, 6, 1), (64, 1, 5), (64, 1, 5)),
((4, 5, 6), 1, False, None, (30, 6, 1), (30, 6, 1)),
((4, 5, 6), -1, False, (30, 6, 1), (30, 6, 1), (30, 6, 1)),
# Strided inputs and outputs
((13, 4), None, False, (4, 1), (12, 1), (24, 1)),
((13, 4), 0, False, (1, 4), (64, 1), (1, 4)),
((13, 4), 1, False, (1, 4), (64, 1), (1, 4)),
((13, 4), None, False, (4, 1), (4, 1), (4, 1)),
((13, 4), 0, False, (13, 1), (13, 1), (13, 1)),
((13, 4), 1, False, (13, 1), (13, 1), (13, 1)),
]
# Tolerance configuration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment