Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
f4bf6ac9
Unverified
Commit
f4bf6ac9
authored
Nov 19, 2025
by
thatPepe
Committed by
GitHub
Nov 19, 2025
Browse files
Merge pull request #617 from gongchensu/feature/fix_randomSample_tests
Fix random_sample test for new framework API.
parents
79b70e58
77a96137
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
9 deletions
+22
-9
test/infinicore/ops/random_sample.py
test/infinicore/ops/random_sample.py
+22
-9
No files found.
test/infinicore/ops/random_sample.py
View file @
f4bf6ac9
...
@@ -129,9 +129,9 @@ class OpTest(BaseOperatorTest):
...
@@ -129,9 +129,9 @@ class OpTest(BaseOperatorTest):
def
get_test_cases
(
self
):
def
get_test_cases
(
self
):
return
parse_test_cases
()
return
parse_test_cases
()
def
prepare_inputs_and_kwargs
(
self
,
test_case
,
device
):
def
prepare_
pytorch_
inputs_and_kwargs
(
self
,
test_case
,
device
):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors"""
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors"""
inputs
,
kwargs
=
super
().
prepare_inputs_and_kwargs
(
test_case
,
device
)
inputs
,
kwargs
=
super
().
prepare_
pytorch_
inputs_and_kwargs
(
test_case
,
device
)
# If we already have stored logits (from a previous call), reuse them
# If we already have stored logits (from a previous call), reuse them
# to ensure consistency across multiple calls for the same test case
# to ensure consistency across multiple calls for the same test case
...
@@ -209,26 +209,27 @@ class OpTest(BaseOperatorTest):
...
@@ -209,26 +209,27 @@ class OpTest(BaseOperatorTest):
try
:
try
:
# Try the standard comparison first
# Try the standard comparison first
# This will call prepare_inputs_and_kwargs which will set self._current_logits
# This will call prepare_
pytorch_
inputs_and_kwargs which will set self._current_logits
return
super
().
run_test
(
device
,
test_case
,
config
)
return
super
().
run_test
(
device
,
test_case
,
config
)
except
AssertionError
as
original_error
:
except
AssertionError
as
original_error
:
# If standard comparison fails, check if this is a valid case where
# If standard comparison fails, check if this is a valid case where
# indices differ but logits values are equal
# indices differ but logits values are equal
# Only handle if we have stored logits (from prepare_inputs_and_kwargs)
# Only handle if we have stored logits (from prepare_
pytorch_
inputs_and_kwargs)
if
self
.
_current_logits
is
None
:
if
self
.
_current_logits
is
None
:
raise
raise
logits_tensor
=
self
.
_current_logits
logits_tensor
=
self
.
_current_logits
# Re-run operations with the same logits to get results for comparison
# Re-run operations with the same logits to get results for comparison
# prepare_inputs_and_kwargs will reuse self._current_logits if it exists
# prepare_pytorch_inputs_and_kwargs will reuse self._current_logits if it exists
from
framework.base
import
TestResult
from
framework.utils
import
(
from
framework.utils
import
(
infinicore_tensor_from_torch
,
convert_infinicore_to_torch
,
convert_infinicore_to_torch
,
infinicore_tensor_from_torch
,
)
)
inputs
,
kwargs
=
self
.
prepare_inputs_and_kwargs
(
test_case
,
device
)
inputs
,
kwargs
=
self
.
prepare_
pytorch_
inputs_and_kwargs
(
test_case
,
device
)
# Prepare infinicore inputs
# Prepare infinicore inputs
infini_inputs
=
[]
infini_inputs
=
[]
...
@@ -268,7 +269,13 @@ class OpTest(BaseOperatorTest):
...
@@ -268,7 +269,13 @@ class OpTest(BaseOperatorTest):
# Check if indices are equal (standard case)
# Check if indices are equal (standard case)
if
ic_idx
==
ref_idx
:
if
ic_idx
==
ref_idx
:
return
True
,
"passed"
# Return a successful TestResult object
return
TestResult
(
success
=
True
,
return_code
=
0
,
test_case
=
test_case
,
device
=
device
,
)
# Special case: indices differ but logits values are equal
# Special case: indices differ but logits values are equal
# This is valid for random_sample when multiple indices have the same logits value
# This is valid for random_sample when multiple indices have the same logits value
...
@@ -277,7 +284,13 @@ class OpTest(BaseOperatorTest):
...
@@ -277,7 +284,13 @@ class OpTest(BaseOperatorTest):
logits_ic
=
logits_tensor
[
ic_idx
].
item
()
logits_ic
=
logits_tensor
[
ic_idx
].
item
()
if
logits_ic
==
logits_ref
:
if
logits_ic
==
logits_ref
:
# Valid: different indices but same logits value
# Valid: different indices but same logits value
return
True
,
"passed"
# Return a successful TestResult object
return
TestResult
(
success
=
True
,
return_code
=
0
,
test_case
=
test_case
,
device
=
device
,
)
except
(
IndexError
,
RuntimeError
):
except
(
IndexError
,
RuntimeError
):
# If we can't access the logits, fall through to raise the original error
# If we can't access the logits, fall through to raise the original error
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment