Unverified Commit e16b6dbc authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Issue/672 a (#673)

* issue/672 - fixed non-contiguous operations in the test framework
parent 79e6883b
...@@ -13,6 +13,7 @@ from .datatypes import to_torch_dtype, to_infinicore_dtype ...@@ -13,6 +13,7 @@ from .datatypes import to_torch_dtype, to_infinicore_dtype
from .devices import InfiniDeviceNames, torch_device_map from .devices import InfiniDeviceNames, torch_device_map
from .tensor import TensorSpec, TensorInitializer from .tensor import TensorSpec, TensorInitializer
from .utils import ( from .utils import (
clone_torch_tensor,
create_test_comparator, create_test_comparator,
infinicore_tensor_from_torch, infinicore_tensor_from_torch,
) )
...@@ -321,7 +322,7 @@ class BaseOperatorTest(ABC): ...@@ -321,7 +322,7 @@ class BaseOperatorTest(ABC):
for item in input_sequence: for item in input_sequence:
if isinstance(item, torch.Tensor): if isinstance(item, torch.Tensor):
if clone: if clone:
cloned_item = item.clone().detach() cloned_item = clone_torch_tensor(item)
infini_item = infinicore_tensor_from_torch(cloned_item) infini_item = infinicore_tensor_from_torch(cloned_item)
cloned_tensors.append(cloned_item) cloned_tensors.append(cloned_item)
else: else:
...@@ -340,7 +341,7 @@ class BaseOperatorTest(ABC): ...@@ -340,7 +341,7 @@ class BaseOperatorTest(ABC):
if isinstance(inp, torch.Tensor): if isinstance(inp, torch.Tensor):
# Clone only if this input will be used for comparison # Clone only if this input will be used for comparison
if comparison_target == i: if comparison_target == i:
cloned_inp = inp.clone().detach() cloned_inp = clone_torch_tensor(inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp) infini_tensor = infinicore_tensor_from_torch(cloned_inp)
cloned_tensors.append(cloned_inp) cloned_tensors.append(cloned_inp)
else: else:
...@@ -362,7 +363,7 @@ class BaseOperatorTest(ABC): ...@@ -362,7 +363,7 @@ class BaseOperatorTest(ABC):
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
# Check if this tensor is used for output comparison # Check if this tensor is used for output comparison
if key == "out" and comparison_target == "out": if key == "out" and comparison_target == "out":
cloned_value = value.clone().detach() cloned_value = clone_torch_tensor(value)
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value) infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
cloned_tensors.append(cloned_value) cloned_tensors.append(cloned_value)
elif key == "out" and isinstance(comparison_target, int): elif key == "out" and isinstance(comparison_target, int):
...@@ -566,12 +567,12 @@ class BaseOperatorTest(ABC): ...@@ -566,12 +567,12 @@ class BaseOperatorTest(ABC):
elif comparison_target == "out": elif comparison_target == "out":
# Compare output tensor from kwargs (explicit output) # Compare output tensor from kwargs (explicit output)
torch_comparison = kwargs.get("out") torch_comparison = kwargs.get("out")
infini_comparison = infini_kwargs.get("out") infini_comparison = cloned_tensors[0]
elif isinstance(comparison_target, int): elif isinstance(comparison_target, int):
# Compare specific input tensor (in-place operation on input) # Compare specific input tensor (in-place operation on input)
if 0 <= comparison_target < len(inputs): if 0 <= comparison_target < len(inputs):
torch_comparison = inputs[comparison_target] torch_comparison = inputs[comparison_target]
infini_comparison = infini_inputs[comparison_target] infini_comparison = cloned_tensors[0]
else: else:
raise ValueError( raise ValueError(
f"Invalid comparison target index: {comparison_target}" f"Invalid comparison target index: {comparison_target}"
......
...@@ -118,6 +118,13 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3 ...@@ -118,6 +118,13 @@ def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3
return tolerance["atol"], tolerance["rtol"] return tolerance["atol"], tolerance["rtol"]
def clone_torch_tensor(torch_tensor):
cloned = torch_tensor.clone().detach()
if not torch_tensor.is_contiguous():
cloned = rearrange_tensor(cloned, torch_tensor.stride())
return cloned
def infinicore_tensor_from_torch(torch_tensor): def infinicore_tensor_from_torch(torch_tensor):
infini_device = infinicore.device(torch_tensor.device.type, 0) infini_device = infinicore.device(torch_tensor.device.type, 0)
if torch_tensor.is_contiguous(): if torch_tensor.is_contiguous():
...@@ -152,6 +159,10 @@ def convert_infinicore_to_torch(infini_result): ...@@ -152,6 +159,10 @@ def convert_infinicore_to_torch(infini_result):
dtype=to_torch_dtype(infini_result.dtype), dtype=to_torch_dtype(infini_result.dtype),
device=infini_result.device.type, device=infini_result.device.type,
) )
if not infini_result.is_contiguous():
torch_result_from_infini = rearrange_tensor(
torch_result_from_infini, infini_result.stride()
)
temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini) temp_tensor = infinicore_tensor_from_torch(torch_result_from_infini)
temp_tensor.copy_(infini_result) temp_tensor.copy_(infini_result)
return torch_result_from_infini return torch_result_from_infini
...@@ -223,6 +234,9 @@ def compare_results( ...@@ -223,6 +234,9 @@ def compare_results(
return result_equal return result_equal
# Convert infinicore result to PyTorch tensor for comparison # Convert infinicore result to PyTorch tensor for comparison
if isinstance(infini_result, torch.Tensor):
torch_result_from_infini = infini_result
else:
torch_result_from_infini = convert_infinicore_to_torch(infini_result) torch_result_from_infini = convert_infinicore_to_torch(infini_result)
# Debug mode: detailed comparison # Debug mode: detailed comparison
......
...@@ -49,8 +49,8 @@ _TEST_CASES_DATA = [ ...@@ -49,8 +49,8 @@ _TEST_CASES_DATA = [
((13, 4), 0, False, None, (3,), (3,)), ((13, 4), 0, False, None, (3,), (3,)),
((13, 4), 1, False, (20, 1), (10,), (10,)), ((13, 4), 1, False, (20, 1), (10,), (10,)),
# 3D in-place cases # 3D in-place cases
((4, 5, 6), 1, True, None, (4, 1, 6), (4, 1, 6)), ((4, 5, 6), 1, True, None, (6, 6, 1), (6, 6, 1)),
((4, 5, 6), -1, False, (30, 6, 1), (4, 5), (4, 5)), ((4, 5, 6), -1, False, (30, 6, 1), (5, 1), (5, 1)),
] ]
# Tolerance configuration # Tolerance configuration
......
...@@ -28,7 +28,6 @@ _TEST_CASES_DATA = [ ...@@ -28,7 +28,6 @@ _TEST_CASES_DATA = [
((4, 48, 6), None, None), ((4, 48, 6), None, None),
# Strided tensors # Strided tensors
((1, 2048), (4096, 1), (4096, 1)), ((1, 2048), (4096, 1), (4096, 1)),
((6, 2560), (2048, 1), (2560, 1)),
# Mixed cases # Mixed cases
((8, 16, 32), None, None), ((8, 16, 32), None, None),
# Large tensors # Large tensors
......
...@@ -31,12 +31,12 @@ _TEST_CASES_DATA = [ ...@@ -31,12 +31,12 @@ _TEST_CASES_DATA = [
((4, 5, 6), 1, False, None, None, None), ((4, 5, 6), 1, False, None, None, None),
((4, 5, 6), -1, True, None, None, None), ((4, 5, 6), -1, True, None, None, None),
# 3D in-place cases # 3D in-place cases
((4, 5, 6), 1, False, None, (4, 1, 6), (4, 1, 6)), ((4, 5, 6), 1, False, None, (30, 6, 1), (30, 6, 1)),
((4, 5, 6), -1, False, (30, 6, 1), (64, 1, 5), (64, 1, 5)), ((4, 5, 6), -1, False, (30, 6, 1), (30, 6, 1), (30, 6, 1)),
# Strided inputs and outputs # Strided inputs and outputs
((13, 4), None, False, (4, 1), (12, 1), (24, 1)), ((13, 4), None, False, (4, 1), (4, 1), (4, 1)),
((13, 4), 0, False, (1, 4), (64, 1), (1, 4)), ((13, 4), 0, False, (13, 1), (13, 1), (13, 1)),
((13, 4), 1, False, (1, 4), (64, 1), (1, 4)), ((13, 4), 1, False, (13, 1), (13, 1), (13, 1)),
] ]
# Tolerance configuration # Tolerance configuration
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment