Unverified Commit 6f167986 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #525 from InfiniTensor/issue/524

issue/524 - support unimplemented operator calls
parents 478e102c 9a429ae3
...@@ -209,15 +209,13 @@ class BaseOperatorTest(ABC): ...@@ -209,15 +209,13 @@ class BaseOperatorTest(ABC):
"""Return dtype combinations for mixed dtype tests""" """Return dtype combinations for mixed dtype tests"""
return None return None
@abstractmethod
def torch_operator(self, *inputs, out=None, **kwargs): def torch_operator(self, *inputs, out=None, **kwargs):
"""Unified PyTorch operator function""" """Unified PyTorch operator function - can be overridden or return None"""
pass raise NotImplementedError("torch_operator not implemented")
@abstractmethod
def infinicore_operator(self, *inputs, out=None, **kwargs): def infinicore_operator(self, *inputs, out=None, **kwargs):
"""Unified Infinicore operator function""" """Unified Infinicore operator function - can be overridden or return None"""
pass raise NotImplementedError("infinicore_operator not implemented")
def create_strided_tensor( def create_strided_tensor(
self, shape, strides, dtype, device, init_mode=TensorInitializer.RANDOM self, shape, strides, dtype, device, init_mode=TensorInitializer.RANDOM
...@@ -301,6 +299,71 @@ class BaseOperatorTest(ABC): ...@@ -301,6 +299,71 @@ class BaseOperatorTest(ABC):
else: else:
infini_inputs.append(inp) infini_inputs.append(inp)
# Check if operators are implemented
torch_implemented = True
infini_implemented = True
try:
torch_result = self.torch_operator(*inputs, **kwargs)
if torch_result is None:
torch_implemented = False
except NotImplementedError:
torch_implemented = False
torch_result = None
try:
infini_result = self.infinicore_operator(*infini_inputs, **kwargs)
if infini_result is None:
infini_implemented = False
except NotImplementedError:
infini_implemented = False
infini_result = None
# If neither operator is implemented, skip the test
if not torch_implemented and not infini_implemented:
print(
f"⚠ {self.operator_name} {mode_name}: Both operators not implemented - test skipped"
)
return
# If only one operator is implemented, run it without comparison
if not torch_implemented or not infini_implemented:
missing_op = (
"torch_operator" if not torch_implemented else "infinicore_operator"
)
print(
f"⚠ {self.operator_name} {mode_name}: {missing_op} not implemented - running single operator without comparison"
)
# Run the available operator for benchmarking if requested
if config.bench:
if torch_implemented:
def torch_op():
return self.torch_operator(*inputs, **kwargs)
profile_operation(
f"PyTorch {self.operator_name} {mode_name}",
torch_op,
device_str,
config.num_prerun,
config.num_iterations,
)
if infini_implemented:
def infini_op():
return self.infinicore_operator(*infini_inputs, **kwargs)
profile_operation(
f"Infinicore {self.operator_name} {mode_name}",
infini_op,
device_str,
config.num_prerun,
config.num_iterations,
)
return
# Both operators are implemented - proceed with normal comparison
if test_case.operation_mode == TestCase.OUT_OF_PLACE: if test_case.operation_mode == TestCase.OUT_OF_PLACE:
def torch_op(): def torch_op():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment