Commit cf4403d6 authored by wooway777's avatar wooway777
Browse files

issue/603 - reduced test tensor clones

parent 17f65139
......@@ -35,7 +35,7 @@ static void calculate(
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(indices->device()), &desc,
indices->desc(), logits->desc()));
cache.put(seed, desc);
} else {
......
......@@ -260,9 +260,7 @@ class TestRunner:
return False
except Exception as e:
error_msg = (
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
)
error_msg = f"Error: {e}"
print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg)
......@@ -392,7 +390,7 @@ class BaseOperatorTest(ABC):
return spec.create_torch_tensor(device)
return spec
def prepare_inputs_and_kwargs(self, test_case, device):
def prepare_pytorch_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
"""
......@@ -455,6 +453,71 @@ class BaseOperatorTest(ABC):
return inputs, kwargs
def prepare_infinicore_list(self, input_sequence, clone=False):
cloned_tensors = []
infini_list = []
for item in input_sequence:
if isinstance(item, torch.Tensor):
if clone:
cloned_item = item.clone().detach()
infini_item = infinicore_tensor_from_torch(cloned_item)
cloned_tensors.append(cloned_item)
else:
infini_item = infinicore_tensor_from_torch(item)
else:
infini_item = item
infini_list.append(infini_item)
return infini_list, cloned_tensors
def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target):
cloned_tensors = []
infini_inputs = []
# Prepare infinicore inputs - only clone if needed for comparison
for i, inp in enumerate(inputs):
if isinstance(inp, torch.Tensor):
# Clone only if this input will be used for comparison
if comparison_target == i:
cloned_inp = inp.clone().detach()
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
cloned_tensors.append(cloned_inp)
else:
# For non-comparison inputs, we can use the original (but still need to convert)
infini_tensor = infinicore_tensor_from_torch(inp)
infini_inputs.append(infini_tensor)
elif isinstance(inp, (tuple, list)):
infini_list, cloned_list = self.prepare_infinicore_list(
inp, comparison_target == i
)
infini_inputs.append(infini_list)
cloned_tensors.append(cloned_list)
else:
infini_inputs.append(inp)
# Prepare infinicore kwargs
infini_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, torch.Tensor):
# Check if this tensor is used for output comparison
if key == "out" and comparison_target == "out":
cloned_value = value.clone().detach()
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
cloned_tensors.append(cloned_value)
elif key == "out" and isinstance(comparison_target, int):
infini_kwargs[key] = infini_inputs[comparison_target]
else:
infini_kwargs[key] = infinicore_tensor_from_torch(value)
elif isinstance(value, (tuple, list)):
infini_list, cloned_list = self.prepare_infinicore_list(
value, key == "out"
)
cloned_tensors.append(cloned_list)
infini_kwargs[key] = infini_list
else:
infini_kwargs[key] = value
return infini_inputs, infini_kwargs, cloned_tensors
def run_test(self, device, test_case, config):
"""
Unified test execution flow
......@@ -478,66 +541,15 @@ class BaseOperatorTest(ABC):
)
# Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
# For in-place operations on input tensors, we need to preserve the original state
original_inputs = []
if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
# This is an in-place operation on an input tensor
# Store original values for comparison
for inp in inputs:
if isinstance(inp, torch.Tensor):
original_inputs.append(inp.clone().detach())
else:
original_inputs.append(inp)
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs = []
torch_input_clones = []
for inp in inputs:
if isinstance(inp, torch.Tensor):
cloned_inp = inp.clone().detach()
torch_input_clones.append(cloned_inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
infini_inputs.append(infini_tensor)
else:
infini_inputs.append(inp)
infini_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, torch.Tensor):
# Clone tensor and convert to infinicore
cloned_value = value.clone().detach()
torch_input_clones.append(cloned_value)
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
else:
# Pass through non-tensor values (scalars, strings, etc.)
infini_kwargs[key] = value
inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
# Determine comparison target
comparison_target = test_case.comparison_target
# Handle infinicore output
infini_kwargs = kwargs.copy()
if "out" in infini_kwargs:
out_value = infini_kwargs["out"]
if isinstance(out_value, torch.Tensor):
# Single tensor output
if isinstance(comparison_target, int):
infini_kwargs["out"] = infini_inputs[comparison_target]
else:
cloned_out = out_value.clone().detach()
torch_input_clones.append(cloned_out)
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
elif isinstance(out_value, (tuple, list)):
# Multiple tensor outputs
infini_outputs = []
for tensor in out_value:
cloned_tensor = tensor.clone().detach()
torch_input_clones.append(cloned_tensor)
infini_outputs.append(infinicore_tensor_from_torch(cloned_tensor))
infini_kwargs["out"] = tuple(infini_outputs)
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs, infini_kwargs, cloned_tensors = (
self.prepare_infinicore_inputs_and_kwargs(inputs, kwargs, comparison_target)
)
# Check operator implementations
torch_implemented = True
......@@ -698,7 +710,7 @@ class BaseOperatorTest(ABC):
is_valid = compare_fn(infini_comparison, torch_comparison)
if not is_valid:
raise AssertionError(f"Result comparison failed for {test_case}")
raise AssertionError(f"Result comparison failed.")
# ==========================================================================
# UNIFIED BENCHMARKING LOGIC
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment