Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
5028ea42
Unverified
Commit
5028ea42
authored
Nov 14, 2025
by
pengcheng888
Committed by
GitHub
Nov 14, 2025
Browse files
Merge pull request #597 from pengcheng888/issue/596
issue/596 - 将functional.py中的函数,拆成functional文件夹中的函数
parents
1a618ff0
3e8c6df1
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
161 additions
and
128 deletions
+161
-128
python/infinicore/nn/__init__.py
python/infinicore/nn/__init__.py
+3
-3
python/infinicore/nn/functional.py
python/infinicore/nn/functional.py
+0
-101
python/infinicore/nn/functional/__init__.py
python/infinicore/nn/functional/__init__.py
+13
-0
python/infinicore/nn/functional/causal_softmax.py
python/infinicore/nn/functional/causal_softmax.py
+15
-0
python/infinicore/nn/functional/random_sample.py
python/infinicore/nn/functional/random_sample.py
+38
-0
python/infinicore/nn/functional/rms_norm.py
python/infinicore/nn/functional/rms_norm.py
+26
-0
python/infinicore/nn/functional/silu.py
python/infinicore/nn/functional/silu.py
+23
-0
python/infinicore/nn/functional/swiglu.py
python/infinicore/nn/functional/swiglu.py
+15
-0
test/infinicore/ops/random_sample.py
test/infinicore/ops/random_sample.py
+28
-24
No files found.
python/infinicore/nn/__init__.py
View file @
5028ea42
from
infinicore.nn
import
(
from
infinicore.nn
import
functional
functional
as
functional
,
)
__all__
=
[
"functional"
]
python/infinicore/nn/functional.py
deleted
100644 → 0
View file @
1a618ff0
import
infinicore
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
__all__
=
[
"causal_softmax"
,
"random_sample"
,
"rms_norm"
,
"silu"
,
"swiglu"
]
def
causal_softmax
(
input
:
Tensor
,
out
=
None
)
->
Tensor
:
r
"""Apply a causal softmax function."""
if
out
is
None
:
return
Tensor
(
_infinicore
.
causal_softmax
(
input
.
_underlying
))
_infinicore
.
causal_softmax_
(
out
.
_underlying
,
input
.
_underlying
)
return
out
def
rms_norm
(
input
:
Tensor
,
normalized_shape
:
list
[
int
],
weight
:
Tensor
,
eps
:
float
=
1e-5
,
*
,
out
=
None
,
)
->
Tensor
:
r
"""Apply Root Mean Square Layer Normalization."""
assert
normalized_shape
==
weight
.
shape
,
(
"normalized_shape does not match weight.shape."
)
if
out
is
None
:
return
Tensor
(
_infinicore
.
rms_norm
(
input
.
_underlying
,
weight
.
_underlying
,
eps
))
_infinicore
.
rms_norm_
(
out
.
_underlying
,
input
.
_underlying
,
weight
.
_underlying
,
eps
)
return
out
def
silu
(
input
:
Tensor
,
inplace
:
bool
=
False
,
*
,
out
=
None
)
->
Tensor
:
r
"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
if
infinicore
.
use_ntops
and
input
.
device
.
type
in
(
"cuda"
,
"musa"
)
and
out
is
None
:
return
infinicore
.
ntops
.
torch
.
silu
(
input
,
inplace
=
inplace
)
if
inplace
:
_infinicore
.
silu_
(
input
.
_underlying
,
input
.
_underlying
)
return
input
if
out
is
None
:
return
Tensor
(
_infinicore
.
silu
(
input
.
_underlying
))
_infinicore
.
silu_
(
out
.
_underlying
,
input
.
_underlying
)
return
out
def
swiglu
(
input
:
Tensor
,
other
:
Tensor
,
*
,
out
=
None
):
r
"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
if
out
is
None
:
return
Tensor
(
_infinicore
.
swiglu
(
input
.
_underlying
,
other
.
_underlying
))
_infinicore
.
swiglu_
(
out
.
_underlying
,
input
.
_underlying
,
other
.
_underlying
)
return
out
def
random_sample
(
logits
:
Tensor
,
random_val
:
float
,
topp
:
float
,
topk
:
int
,
temperature
:
float
,
*
,
out
=
None
,
)
->
Tensor
:
r
"""Sample an index from logits with nucleus/top-k filtering."""
if
out
is
None
:
return
Tensor
(
_infinicore
.
random_sample
(
logits
.
_underlying
,
random_val
,
topp
,
topk
,
temperature
,
)
)
_infinicore
.
random_sample_
(
out
.
_underlying
,
logits
.
_underlying
,
random_val
,
topp
,
topk
,
temperature
,
)
return
out
python/infinicore/nn/functional/__init__.py
0 → 100644
View file @
5028ea42
from
.causal_softmax
import
causal_softmax
from
.random_sample
import
random_sample
from
.rms_norm
import
rms_norm
from
.silu
import
silu
from
.swiglu
import
swiglu
__all__
=
[
"causal_softmax"
,
"random_sample"
,
"rms_norm"
,
"silu"
,
"swiglu"
,
]
python/infinicore/nn/functional/causal_softmax.py
0 → 100644
View file @
5028ea42
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
__all__
=
[
"causal_softmax"
]
def
causal_softmax
(
input
:
Tensor
,
out
=
None
)
->
Tensor
:
r
"""Apply a causal softmax function."""
if
out
is
None
:
return
Tensor
(
_infinicore
.
causal_softmax
(
input
.
_underlying
))
_infinicore
.
causal_softmax_
(
out
.
_underlying
,
input
.
_underlying
)
return
out
python/infinicore/nn/functional/random_sample.py
0 → 100644
View file @
5028ea42
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
__all__
=
[
"random_sample"
]
def
random_sample
(
logits
:
Tensor
,
random_val
:
float
,
topp
:
float
,
topk
:
int
,
temperature
:
float
,
*
,
out
=
None
,
)
->
Tensor
:
r
"""Sample an index from logits with nucleus/top-k filtering."""
if
out
is
None
:
return
Tensor
(
_infinicore
.
random_sample
(
logits
.
_underlying
,
random_val
,
topp
,
topk
,
temperature
,
)
)
_infinicore
.
random_sample_
(
out
.
_underlying
,
logits
.
_underlying
,
random_val
,
topp
,
topk
,
temperature
,
)
return
out
python/infinicore/nn/functional/rms_norm.py
0 → 100644
View file @
5028ea42
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
__all__
=
[
"rms_norm"
]
def
rms_norm
(
input
:
Tensor
,
normalized_shape
:
list
[
int
],
weight
:
Tensor
,
eps
:
float
=
1e-5
,
*
,
out
=
None
,
)
->
Tensor
:
r
"""Apply Root Mean Square Layer Normalization."""
assert
normalized_shape
==
weight
.
shape
,
(
"normalized_shape does not match weight.shape."
)
if
out
is
None
:
return
Tensor
(
_infinicore
.
rms_norm
(
input
.
_underlying
,
weight
.
_underlying
,
eps
))
_infinicore
.
rms_norm_
(
out
.
_underlying
,
input
.
_underlying
,
weight
.
_underlying
,
eps
)
return
out
python/infinicore/nn/functional/silu.py
0 → 100644
View file @
5028ea42
import
infinicore
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
__all__
=
[
"silu"
]
def
silu
(
input
:
Tensor
,
inplace
:
bool
=
False
,
*
,
out
=
None
)
->
Tensor
:
r
"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
if
infinicore
.
use_ntops
and
input
.
device
.
type
in
(
"cuda"
,
"musa"
)
and
out
is
None
:
return
infinicore
.
ntops
.
torch
.
silu
(
input
,
inplace
=
inplace
)
if
inplace
:
_infinicore
.
silu_
(
input
.
_underlying
,
input
.
_underlying
)
return
input
if
out
is
None
:
return
Tensor
(
_infinicore
.
silu
(
input
.
_underlying
))
_infinicore
.
silu_
(
out
.
_underlying
,
input
.
_underlying
)
return
out
python/infinicore/nn/functional/swiglu.py
0 → 100644
View file @
5028ea42
from
infinicore.lib
import
_infinicore
from
infinicore.tensor
import
Tensor
__all__
=
[
"swiglu"
]
def
swiglu
(
input
:
Tensor
,
other
:
Tensor
,
*
,
out
=
None
):
r
"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
if
out
is
None
:
return
Tensor
(
_infinicore
.
swiglu
(
input
.
_underlying
,
other
.
_underlying
))
_infinicore
.
swiglu_
(
out
.
_underlying
,
input
.
_underlying
,
other
.
_underlying
)
return
out
test/infinicore/ops/random_sample.py
View file @
5028ea42
...
@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
...
@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
idx
=
torch
.
searchsorted
(
cum_probs
,
threshold
)
idx
=
torch
.
searchsorted
(
cum_probs
,
threshold
)
except
Exception
:
except
Exception
:
indices
=
(
cum_probs
>=
threshold
).
nonzero
(
as_tuple
=
True
)[
0
]
indices
=
(
cum_probs
>=
threshold
).
nonzero
(
as_tuple
=
True
)[
0
]
idx
=
indices
[
0
]
if
indices
.
numel
()
>
0
else
torch
.
tensor
(
len
(
cum_probs
)
-
1
,
device
=
cum_probs
.
device
)
idx
=
(
indices
[
0
]
if
indices
.
numel
()
>
0
else
torch
.
tensor
(
len
(
cum_probs
)
-
1
,
device
=
cum_probs
.
device
)
)
return
sorted_indices
[
idx
]
return
sorted_indices
[
idx
]
return
torch
.
argmax
(
data
)
return
torch
.
argmax
(
data
)
...
@@ -191,41 +195,41 @@ class OpTest(BaseOperatorTest):
...
@@ -191,41 +195,41 @@ class OpTest(BaseOperatorTest):
def
run_test
(
self
,
device
,
test_case
,
config
):
def
run_test
(
self
,
device
,
test_case
,
config
):
"""
"""
Override run_test to handle random_sample's special comparison logic.
Override run_test to handle random_sample's special comparison logic.
For random_sample, if the indices differ but the logits values at those
For random_sample, if the indices differ but the logits values at those
indices are equal, the result is still considered valid. This handles
indices are equal, the result is still considered valid. This handles
cases where multiple valid indices exist due to floating-point precision.
cases where multiple valid indices exist due to floating-point precision.
This is necessary because random_sample can return different valid indices
This is necessary because random_sample can return different valid indices
when multiple positions have the same logits value, especially with
when multiple positions have the same logits value, especially with
low-precision types like bfloat16 due to floating-point rounding.
low-precision types like bfloat16 due to floating-point rounding.
"""
"""
# Clear stored logits before test to ensure fresh generation
# Clear stored logits before test to ensure fresh generation
self
.
_current_logits
=
None
self
.
_current_logits
=
None
try
:
try
:
# Try the standard comparison first
# Try the standard comparison first
# This will call prepare_inputs_and_kwargs which will set self._current_logits
# This will call prepare_inputs_and_kwargs which will set self._current_logits
return
super
().
run_test
(
device
,
test_case
,
config
)
return
super
().
run_test
(
device
,
test_case
,
config
)
except
AssertionError
:
except
AssertionError
as
original_error
:
# If standard comparison fails, check if this is a valid case where
# If standard comparison fails, check if this is a valid case where
# indices differ but logits values are equal
# indices differ but logits values are equal
# Only handle if we have stored logits (from prepare_inputs_and_kwargs)
# Only handle if we have stored logits (from prepare_inputs_and_kwargs)
if
self
.
_current_logits
is
None
:
if
self
.
_current_logits
is
None
:
raise
raise
logits_tensor
=
self
.
_current_logits
logits_tensor
=
self
.
_current_logits
# Re-run operations with the same logits to get results for comparison
# Re-run operations with the same logits to get results for comparison
# prepare_inputs_and_kwargs will reuse self._current_logits if it exists
# prepare_inputs_and_kwargs will reuse self._current_logits if it exists
from
framework.utils
import
(
from
framework.utils
import
(
infinicore_tensor_from_torch
,
infinicore_tensor_from_torch
,
convert_infinicore_to_torch
,
convert_infinicore_to_torch
,
)
)
inputs
,
kwargs
=
self
.
prepare_inputs_and_kwargs
(
test_case
,
device
)
inputs
,
kwargs
=
self
.
prepare_inputs_and_kwargs
(
test_case
,
device
)
# Prepare infinicore inputs
# Prepare infinicore inputs
infini_inputs
=
[]
infini_inputs
=
[]
for
inp
in
inputs
:
for
inp
in
inputs
:
...
@@ -235,37 +239,37 @@ class OpTest(BaseOperatorTest):
...
@@ -235,37 +239,37 @@ class OpTest(BaseOperatorTest):
infini_inputs
.
append
(
infini_tensor
)
infini_inputs
.
append
(
infini_tensor
)
else
:
else
:
infini_inputs
.
append
(
inp
)
infini_inputs
.
append
(
inp
)
infini_kwargs
=
kwargs
.
copy
()
infini_kwargs
=
kwargs
.
copy
()
if
"out"
in
infini_kwargs
and
isinstance
(
infini_kwargs
[
"out"
],
torch
.
Tensor
):
if
"out"
in
infini_kwargs
and
isinstance
(
infini_kwargs
[
"out"
],
torch
.
Tensor
):
cloned_out
=
infini_kwargs
[
"out"
].
clone
().
detach
()
cloned_out
=
infini_kwargs
[
"out"
].
clone
().
detach
()
infini_kwargs
[
"out"
]
=
infinicore_tensor_from_torch
(
cloned_out
)
infini_kwargs
[
"out"
]
=
infinicore_tensor_from_torch
(
cloned_out
)
# Run both operators
# Run both operators
torch_result
=
self
.
torch_operator
(
*
inputs
,
**
kwargs
)
torch_result
=
self
.
torch_operator
(
*
inputs
,
**
kwargs
)
infini_result
=
self
.
infinicore_operator
(
*
infini_inputs
,
**
infini_kwargs
)
infini_result
=
self
.
infinicore_operator
(
*
infini_inputs
,
**
infini_kwargs
)
# Extract indices from results
# Extract indices from results
comparison_target
=
test_case
.
comparison_target
comparison_target
=
test_case
.
comparison_target
if
comparison_target
==
"out"
:
if
comparison_target
==
"out"
:
# Compare output tensor from kwargs
# Compare output tensor from kwargs
ref_idx
=
kwargs
[
"out"
].
item
()
ref_idx
=
kwargs
[
"out"
].
item
()
torch_result_from_infini
=
convert_infinicore_to_torch
(
torch_result_from_infini
=
convert_infinicore_to_torch
(
infini_kwargs
[
"out"
]
,
kwargs
[
"out"
]
infini_kwargs
[
"out"
]
)
)
ic_idx
=
torch_result_from_infini
.
item
()
ic_idx
=
torch_result_from_infini
.
item
()
else
:
else
:
# Compare return values
# Compare return values
ref_idx
=
torch_result
.
item
()
ref_idx
=
torch_result
.
item
()
torch_result_from_infini
=
convert_infinicore_to_torch
(
torch_result_from_infini
=
convert_infinicore_to_torch
(
infini_result
)
infini_result
,
torch_result
)
ic_idx
=
torch_result_from_infini
.
item
()
ic_idx
=
torch_result_from_infini
.
item
()
# Check if indices are equal (standard case)
# Check if indices are equal (standard case)
if
ic_idx
==
ref_idx
:
if
ic_idx
==
ref_idx
:
return
return
True
,
"passed"
# Special case: indices differ but logits values are equal
# Special case: indices differ but logits values are equal
# This is valid for random_sample when multiple indices have the same logits value
# This is valid for random_sample when multiple indices have the same logits value
try
:
try
:
...
@@ -273,13 +277,13 @@ class OpTest(BaseOperatorTest):
...
@@ -273,13 +277,13 @@ class OpTest(BaseOperatorTest):
logits_ic
=
logits_tensor
[
ic_idx
].
item
()
logits_ic
=
logits_tensor
[
ic_idx
].
item
()
if
logits_ic
==
logits_ref
:
if
logits_ic
==
logits_ref
:
# Valid: different indices but same logits value
# Valid: different indices but same logits value
return
return
True
,
"passed"
except
(
IndexError
,
RuntimeError
):
except
(
IndexError
,
RuntimeError
):
# If we can't access the logits, fall through to raise the original error
# If we can't access the logits, fall through to raise the original error
pass
pass
# If we get here, the results are truly different
# If we get here, the results are truly different
raise
raise
original_error
def
main
():
def
main
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment