Commit 77a96137 authored by zhuyue's avatar zhuyue Committed by zhuyue
Browse files

Fix random_sample test for new framework API.

parent 74934cdf
...@@ -129,9 +129,9 @@ class OpTest(BaseOperatorTest): ...@@ -129,9 +129,9 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self): def get_test_cases(self):
return parse_test_cases() return parse_test_cases()
def prepare_inputs_and_kwargs(self, test_case, device): def prepare_pytorch_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors""" """Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors"""
inputs, kwargs = super().prepare_inputs_and_kwargs(test_case, device) inputs, kwargs = super().prepare_pytorch_inputs_and_kwargs(test_case, device)
# If we already have stored logits (from a previous call), reuse them # If we already have stored logits (from a previous call), reuse them
# to ensure consistency across multiple calls for the same test case # to ensure consistency across multiple calls for the same test case
...@@ -209,26 +209,27 @@ class OpTest(BaseOperatorTest): ...@@ -209,26 +209,27 @@ class OpTest(BaseOperatorTest):
try: try:
# Try the standard comparison first # Try the standard comparison first
# This will call prepare_inputs_and_kwargs which will set self._current_logits # This will call prepare_pytorch_inputs_and_kwargs which will set self._current_logits
return super().run_test(device, test_case, config) return super().run_test(device, test_case, config)
except AssertionError as original_error: except AssertionError as original_error:
# If standard comparison fails, check if this is a valid case where # If standard comparison fails, check if this is a valid case where
# indices differ but logits values are equal # indices differ but logits values are equal
# Only handle if we have stored logits (from prepare_inputs_and_kwargs) # Only handle if we have stored logits (from prepare_pytorch_inputs_and_kwargs)
if self._current_logits is None: if self._current_logits is None:
raise raise
logits_tensor = self._current_logits logits_tensor = self._current_logits
# Re-run operations with the same logits to get results for comparison # Re-run operations with the same logits to get results for comparison
# prepare_inputs_and_kwargs will reuse self._current_logits if it exists # prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.base import TestResult
from framework.utils import ( from framework.utils import (
infinicore_tensor_from_torch,
convert_infinicore_to_torch, convert_infinicore_to_torch,
infinicore_tensor_from_torch,
) )
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device) inputs, kwargs = self.prepare_pytorch_inputs_and_kwargs(test_case, device)
# Prepare infinicore inputs # Prepare infinicore inputs
infini_inputs = [] infini_inputs = []
...@@ -268,7 +269,13 @@ class OpTest(BaseOperatorTest): ...@@ -268,7 +269,13 @@ class OpTest(BaseOperatorTest):
# Check if indices are equal (standard case) # Check if indices are equal (standard case)
if ic_idx == ref_idx: if ic_idx == ref_idx:
return True, "passed" # Return a successful TestResult object
return TestResult(
success=True,
return_code=0,
test_case=test_case,
device=device,
)
# Special case: indices differ but logits values are equal # Special case: indices differ but logits values are equal
# This is valid for random_sample when multiple indices have the same logits value # This is valid for random_sample when multiple indices have the same logits value
...@@ -277,7 +284,13 @@ class OpTest(BaseOperatorTest): ...@@ -277,7 +284,13 @@ class OpTest(BaseOperatorTest):
logits_ic = logits_tensor[ic_idx].item() logits_ic = logits_tensor[ic_idx].item()
if logits_ic == logits_ref: if logits_ic == logits_ref:
# Valid: different indices but same logits value # Valid: different indices but same logits value
return True, "passed" # Return a successful TestResult object
return TestResult(
success=True,
return_code=0,
test_case=test_case,
device=device,
)
except (IndexError, RuntimeError): except (IndexError, RuntimeError):
# If we can't access the logits, fall through to raise the original error # If we can't access the logits, fall through to raise the original error
pass pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment