Unverified commit 0883d6ee, authored by PanZezhong1725, committed by GitHub
Browse files

Merge pull request #605 from InfiniTensor/issue/603

Issue/603 - 优化张量复制逻辑
parents 17f65139 cf4403d6
...@@ -35,7 +35,7 @@ static void calculate( ...@@ -35,7 +35,7 @@ static void calculate(
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(indices->device()), &desc,
indices->desc(), logits->desc())); indices->desc(), logits->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -260,9 +260,7 @@ class TestRunner: ...@@ -260,9 +260,7 @@ class TestRunner:
return False return False
except Exception as e: except Exception as e:
error_msg = ( error_msg = f"Error: {e}"
f"{test_case} - {InfiniDeviceNames[device]} - Error: {e}"
)
print(f"\033[91m✗\033[0m {error_msg}") print(f"\033[91m✗\033[0m {error_msg}")
self.failed_tests.append(error_msg) self.failed_tests.append(error_msg)
...@@ -392,7 +390,7 @@ class BaseOperatorTest(ABC): ...@@ -392,7 +390,7 @@ class BaseOperatorTest(ABC):
return spec.create_torch_tensor(device) return spec.create_torch_tensor(device)
return spec return spec
def prepare_inputs_and_kwargs(self, test_case, device): def prepare_pytorch_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors """Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors
Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs Supports tuple inputs for operators like torch.cat and TensorSpec in kwargs
""" """
...@@ -455,6 +453,71 @@ class BaseOperatorTest(ABC): ...@@ -455,6 +453,71 @@ class BaseOperatorTest(ABC):
return inputs, kwargs return inputs, kwargs
def prepare_infinicore_list(self, input_sequence, clone=False):
    """Convert the items of a sequence into values usable by infinicore.

    Every ``torch.Tensor`` item is passed through
    ``infinicore_tensor_from_torch``; any other item is forwarded unchanged.
    When ``clone`` is true, each tensor is first duplicated with
    ``clone().detach()`` and the duplicate (not the original) is converted.

    Args:
        input_sequence: Iterable of tensors and/or arbitrary other values.
        clone: Whether tensors are duplicated before conversion.

    Returns:
        A pair ``(converted, clones)``. ``converted`` is always a list (even
        for tuple input) holding the converted items in their original order.
        ``clones`` is the list of torch duplicates made along the way; it is
        empty when ``clone`` is false or no tensor was encountered.
    """
    converted = []
    clones = []
    for element in input_sequence:
        if not isinstance(element, torch.Tensor):
            # Non-tensor values (scalars, strings, ...) are passed through.
            converted.append(element)
            continue

        source = element.clone().detach() if clone else element
        wrapped = infinicore_tensor_from_torch(source)
        if clone:
            # Hand the duplicate back to the caller together with the result.
            clones.append(source)
        converted.append(wrapped)
    return converted, clones
def prepare_infinicore_inputs_and_kwargs(self, inputs, kwargs, comparison_target):
    """Build the infinicore positional and keyword arguments for one call.

    Torch tensors are converted with ``infinicore_tensor_from_torch``. A
    tensor is duplicated (``clone().detach()``) before conversion only when
    it is the comparison target; all other tensors are converted directly.

    Args:
        inputs: Positional arguments; tensors, tuples/lists, or other values.
        kwargs: Keyword arguments; values may be tensors, tuples/lists, or
            other values.
        comparison_target: Either the index of the positional input that is
            compared afterwards, the string ``"out"`` when the ``out`` keyword
            argument is compared, or any other value (no match).

    Returns:
        A triple ``(infini_inputs, infini_kwargs, clones)``. ``clones``
        collects the torch duplicates that were created; for tuple/list
        arguments the whole list returned by ``prepare_infinicore_list`` is
        appended as a single nested entry.
    """
    clones = []

    # ---- positional arguments -------------------------------------------
    converted_inputs = []
    for position, arg in enumerate(inputs):
        if isinstance(arg, torch.Tensor):
            if comparison_target == position:
                # The compared input gets its own copy before conversion.
                duplicate = arg.clone().detach()
                wrapped = infinicore_tensor_from_torch(duplicate)
                clones.append(duplicate)
            else:
                # Other inputs are converted straight from the original.
                wrapped = infinicore_tensor_from_torch(arg)
            converted_inputs.append(wrapped)
        elif isinstance(arg, (tuple, list)):
            wrapped_seq, seq_clones = self.prepare_infinicore_list(
                arg, comparison_target == position
            )
            converted_inputs.append(wrapped_seq)
            clones.append(seq_clones)
        else:
            converted_inputs.append(arg)

    # ---- keyword arguments ----------------------------------------------
    converted_kwargs = {}
    for name, val in kwargs.items():
        if isinstance(val, torch.Tensor):
            if name != "out":
                converted_kwargs[name] = infinicore_tensor_from_torch(val)
            elif comparison_target == "out":
                # The explicit output tensor is what will be compared.
                duplicate = val.clone().detach()
                converted_kwargs[name] = infinicore_tensor_from_torch(duplicate)
                clones.append(duplicate)
            elif isinstance(comparison_target, int):
                # Output aliases a positional input: reuse its converted form.
                converted_kwargs[name] = converted_inputs[comparison_target]
            else:
                converted_kwargs[name] = infinicore_tensor_from_torch(val)
        elif isinstance(val, (tuple, list)):
            # Sequence-valued kwargs are cloned only for the "out" key.
            wrapped_seq, seq_clones = self.prepare_infinicore_list(
                val, name == "out"
            )
            clones.append(seq_clones)
            converted_kwargs[name] = wrapped_seq
        else:
            converted_kwargs[name] = val

    return converted_inputs, converted_kwargs, clones
def run_test(self, device, test_case, config): def run_test(self, device, test_case, config):
""" """
Unified test execution flow Unified test execution flow
...@@ -478,66 +541,15 @@ class BaseOperatorTest(ABC): ...@@ -478,66 +541,15 @@ class BaseOperatorTest(ABC):
) )
# Prepare inputs and kwargs with actual tensors # Prepare inputs and kwargs with actual tensors
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device) inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
# For in-place operations on input tensors, we need to preserve the original state
original_inputs = []
if "out" in kwargs and isinstance(kwargs["out"], torch.Tensor):
# This is an in-place operation on an input tensor
# Store original values for comparison
for inp in inputs:
if isinstance(inp, torch.Tensor):
original_inputs.append(inp.clone().detach())
else:
original_inputs.append(inp)
# Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_inputs = []
torch_input_clones = []
for inp in inputs:
if isinstance(inp, torch.Tensor):
cloned_inp = inp.clone().detach()
torch_input_clones.append(cloned_inp)
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
infini_inputs.append(infini_tensor)
else:
infini_inputs.append(inp)
infini_kwargs = {}
for key, value in kwargs.items():
if isinstance(value, torch.Tensor):
# Clone tensor and convert to infinicore
cloned_value = value.clone().detach()
torch_input_clones.append(cloned_value)
infini_kwargs[key] = infinicore_tensor_from_torch(cloned_value)
else:
# Pass through non-tensor values (scalars, strings, etc.)
infini_kwargs[key] = value
# Determine comparison target # Determine comparison target
comparison_target = test_case.comparison_target comparison_target = test_case.comparison_target
# Handle infinicore output # Create infinicore inputs (cloned to avoid in-place modifications affecting reference)
infini_kwargs = kwargs.copy() infini_inputs, infini_kwargs, cloned_tensors = (
if "out" in infini_kwargs: self.prepare_infinicore_inputs_and_kwargs(inputs, kwargs, comparison_target)
out_value = infini_kwargs["out"] )
if isinstance(out_value, torch.Tensor):
# Single tensor output
if isinstance(comparison_target, int):
infini_kwargs["out"] = infini_inputs[comparison_target]
else:
cloned_out = out_value.clone().detach()
torch_input_clones.append(cloned_out)
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
elif isinstance(out_value, (tuple, list)):
# Multiple tensor outputs
infini_outputs = []
for tensor in out_value:
cloned_tensor = tensor.clone().detach()
torch_input_clones.append(cloned_tensor)
infini_outputs.append(infinicore_tensor_from_torch(cloned_tensor))
infini_kwargs["out"] = tuple(infini_outputs)
# Check operator implementations # Check operator implementations
torch_implemented = True torch_implemented = True
...@@ -698,7 +710,7 @@ class BaseOperatorTest(ABC): ...@@ -698,7 +710,7 @@ class BaseOperatorTest(ABC):
is_valid = compare_fn(infini_comparison, torch_comparison) is_valid = compare_fn(infini_comparison, torch_comparison)
if not is_valid: if not is_valid:
raise AssertionError(f"Result comparison failed for {test_case}") raise AssertionError(f"Result comparison failed.")
# ========================================================================== # ==========================================================================
# UNIFIED BENCHMARKING LOGIC # UNIFIED BENCHMARKING LOGIC
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment