Unverified Commit f4bf6ac9 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #617 from gongchensu/feature/fix_randomSample_tests

Fix random_sample test for new framework API.
parents 79b70e58 77a96137
......@@ -129,9 +129,9 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self):
return parse_test_cases()
def prepare_inputs_and_kwargs(self, test_case, device):
def prepare_pytorch_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors"""
inputs, kwargs = super().prepare_inputs_and_kwargs(test_case, device)
inputs, kwargs = super().prepare_pytorch_inputs_and_kwargs(test_case, device)
# If we already have stored logits (from a previous call), reuse them
# to ensure consistency across multiple calls for the same test case
......@@ -209,26 +209,27 @@ class OpTest(BaseOperatorTest):
try:
# Try the standard comparison first
# This will call prepare_inputs_and_kwargs which will set self._current_logits
# This will call prepare_pytorch_inputs_and_kwargs which will set self._current_logits
return super().run_test(device, test_case, config)
except AssertionError as original_error:
# If standard comparison fails, check if this is a valid case where
# indices differ but logits values are equal
# Only handle if we have stored logits (from prepare_inputs_and_kwargs)
# Only handle if we have stored logits (from prepare_pytorch_inputs_and_kwargs)
if self._current_logits is None:
raise
logits_tensor = self._current_logits
# Re-run operations with the same logits to get results for comparison
# prepare_inputs_and_kwargs will reuse self._current_logits if it exists
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.base import TestResult
from framework.utils import (
infinicore_tensor_from_torch,
convert_infinicore_to_torch,
infinicore_tensor_from_torch,
)
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
# Prepare infinicore inputs
infini_inputs = []
......@@ -268,7 +269,13 @@ class OpTest(BaseOperatorTest):
# Check if indices are equal (standard case)
if ic_idx == ref_idx:
return True, "passed"
# Return a successful TestResult object
return TestResult(
success=True,
return_code=0,
test_case=test_case,
device=device,
)
# Special case: indices differ but logits values are equal
# This is valid for random_sample when multiple indices have the same logits value
......@@ -277,7 +284,13 @@ class OpTest(BaseOperatorTest):
logits_ic = logits_tensor[ic_idx].item()
if logits_ic == logits_ref:
# Valid: different indices but same logits value
return True, "passed"
# Return a successful TestResult object
return TestResult(
success=True,
return_code=0,
test_case=test_case,
device=device,
)
except (IndexError, RuntimeError):
# If we can't access the logits, fall through to raise the original error
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment