Commit 3e8c6df1 authored by pengcheng888
Browse files

issue/596 - 将functional.py中的函数,拆成functional文件夹中的函数

parent 1a618ff0
from infinicore.nn import (
functional as functional,
)
from infinicore.nn import functional
__all__ = ["functional"]
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu"]
def causal_softmax(input: Tensor, out=None) -> Tensor:
    r"""Apply a causal softmax function.

    Args:
        input: Source tensor; its ``_underlying`` handle is passed to the
            backend.
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.causal_softmax_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.causal_softmax``.
    """
    if out is not None:
        # Caller supplied a destination: use the trailing-underscore variant.
        _infinicore.causal_softmax_(out._underlying, input._underlying)
        return out

    result = _infinicore.causal_softmax(input._underlying)
    return Tensor(result)
def rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Tensor,
    eps: float = 1e-5,
    *,
    out=None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization.

    Args:
        input: Tensor to normalize.
        normalized_shape: Expected shape of ``weight``; must compare equal to
            ``weight.shape``.
        weight: Tensor forwarded to the backend together with ``input``.
        eps: Value forwarded to the backend (default ``1e-5``).
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.rms_norm_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.rms_norm``.

    Raises:
        AssertionError: If ``normalized_shape`` does not equal
            ``weight.shape``.
    """
    # Raise explicitly rather than using ``assert``: assert statements are
    # removed when Python runs with -O, which would let a mismatched weight
    # reach the backend unchecked. AssertionError is kept so existing callers
    # that catch it keep working.
    if not (normalized_shape == weight.shape):
        raise AssertionError("normalized_shape does not match weight.shape.")

    if out is None:
        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))

    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
    return out
def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.

    Args:
        input: Source tensor.
        inplace: When true, the result is written back into ``input``.
        out: Optional destination tensor, used only when ``inplace`` is false.

    Returns:
        The result of ``infinicore.ntops.torch.silu`` when the ntops path is
        taken; otherwise ``input`` (in-place), ``out`` (destination given),
        or a new ``Tensor`` wrapping the backend result.
    """
    # The ntops implementation is used only when it is enabled, the tensor
    # lives on a "cuda" or "musa" device, and no explicit destination is given.
    take_ntops_path = (
        infinicore.use_ntops
        and input.device.type in ("cuda", "musa")
        and out is None
    )
    if take_ntops_path:
        return infinicore.ntops.torch.silu(input, inplace=inplace)

    if inplace:
        # In-place: the input tensor doubles as the destination.
        target = input
    elif out is None:
        return Tensor(_infinicore.silu(input._underlying))
    else:
        target = out

    _infinicore.silu_(target._underlying, input._underlying)
    return target
def swiglu(input: Tensor, other: Tensor, *, out=None):
    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise.

    Args:
        input: First operand passed to the backend.
        other: Second operand passed to the backend.
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.swiglu_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.swiglu``.
    """
    if out is not None:
        _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
        return out

    result = _infinicore.swiglu(input._underlying, other._underlying)
    return Tensor(result)
def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from logits with nucleus/top-k filtering.

    Args:
        logits: Tensor of logits to sample from.
        random_val: Forwarded unchanged to the backend.
        topp: Forwarded unchanged to the backend.
        topk: Forwarded unchanged to the backend.
        temperature: Forwarded unchanged to the backend.
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.random_sample_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.random_sample``.
    """
    # Both backend entry points take the same trailing arguments; the
    # destination variant only prepends the output handle.
    backend_args = (logits._underlying, random_val, topp, topk, temperature)

    if out is None:
        return Tensor(_infinicore.random_sample(*backend_args))

    _infinicore.random_sample_(out._underlying, *backend_args)
    return out
from .causal_softmax import causal_softmax
from .random_sample import random_sample
from .rms_norm import rms_norm
from .silu import silu
from .swiglu import swiglu
__all__ = [
"causal_softmax",
"random_sample",
"rms_norm",
"silu",
"swiglu",
]
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
__all__ = ["causal_softmax"]
def causal_softmax(input: Tensor, out=None) -> Tensor:
    r"""Apply a causal softmax function.

    Args:
        input: Source tensor; its ``_underlying`` handle is passed to the
            backend.
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.causal_softmax_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.causal_softmax``.
    """
    if out is not None:
        # Caller supplied a destination: use the trailing-underscore variant.
        _infinicore.causal_softmax_(out._underlying, input._underlying)
        return out

    result = _infinicore.causal_softmax(input._underlying)
    return Tensor(result)
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
__all__ = ["random_sample"]
def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from logits with nucleus/top-k filtering.

    Args:
        logits: Tensor of logits to sample from.
        random_val: Forwarded unchanged to the backend.
        topp: Forwarded unchanged to the backend.
        topk: Forwarded unchanged to the backend.
        temperature: Forwarded unchanged to the backend.
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.random_sample_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.random_sample``.
    """
    # Both backend entry points take the same trailing arguments; the
    # destination variant only prepends the output handle.
    backend_args = (logits._underlying, random_val, topp, topk, temperature)

    if out is None:
        return Tensor(_infinicore.random_sample(*backend_args))

    _infinicore.random_sample_(out._underlying, *backend_args)
    return out
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
__all__ = ["rms_norm"]
def rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Tensor,
    eps: float = 1e-5,
    *,
    out=None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization.

    Args:
        input: Tensor to normalize.
        normalized_shape: Expected shape of ``weight``; must compare equal to
            ``weight.shape``.
        weight: Tensor forwarded to the backend together with ``input``.
        eps: Value forwarded to the backend (default ``1e-5``).
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.rms_norm_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.rms_norm``.

    Raises:
        AssertionError: If ``normalized_shape`` does not equal
            ``weight.shape``.
    """
    # Raise explicitly rather than using ``assert``: assert statements are
    # removed when Python runs with -O, which would let a mismatched weight
    # reach the backend unchecked. AssertionError is kept so existing callers
    # that catch it keep working.
    if not (normalized_shape == weight.shape):
        raise AssertionError("normalized_shape does not match weight.shape.")

    if out is None:
        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))

    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
    return out
import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
__all__ = ["silu"]
def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.

    Args:
        input: Source tensor.
        inplace: When true, the result is written back into ``input``.
        out: Optional destination tensor, used only when ``inplace`` is false.

    Returns:
        The result of ``infinicore.ntops.torch.silu`` when the ntops path is
        taken; otherwise ``input`` (in-place), ``out`` (destination given),
        or a new ``Tensor`` wrapping the backend result.
    """
    # The ntops implementation is used only when it is enabled, the tensor
    # lives on a "cuda" or "musa" device, and no explicit destination is given.
    take_ntops_path = (
        infinicore.use_ntops
        and input.device.type in ("cuda", "musa")
        and out is None
    )
    if take_ntops_path:
        return infinicore.ntops.torch.silu(input, inplace=inplace)

    if inplace:
        # In-place: the input tensor doubles as the destination.
        target = input
    elif out is None:
        return Tensor(_infinicore.silu(input._underlying))
    else:
        target = out

    _infinicore.silu_(target._underlying, input._underlying)
    return target
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
__all__ = ["swiglu"]
def swiglu(input: Tensor, other: Tensor, *, out=None):
    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise.

    Args:
        input: First operand passed to the backend.
        other: Second operand passed to the backend.
        out: Optional destination tensor. When provided, the backend writes
            into it through ``_infinicore.swiglu_``.

    Returns:
        ``out`` when it was supplied, otherwise a new ``Tensor`` wrapping the
        value returned by ``_infinicore.swiglu``.
    """
    if out is not None:
        _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
        return out

    result = _infinicore.swiglu(input._underlying, other._underlying)
    return Tensor(result)
......@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
idx = torch.searchsorted(cum_probs, threshold)
except Exception:
indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
idx = (
indices[0]
if indices.numel() > 0
else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
)
return sorted_indices[idx]
return torch.argmax(data)
......@@ -191,41 +195,41 @@ class OpTest(BaseOperatorTest):
def run_test(self, device, test_case, config):
"""
Override run_test to handle random_sample's special comparison logic.
For random_sample, if the indices differ but the logits values at those
indices are equal, the result is still considered valid. This handles
cases where multiple valid indices exist due to floating-point precision.
This is necessary because random_sample can return different valid indices
when multiple positions have the same logits value, especially with
low-precision types like bfloat16 due to floating-point rounding.
"""
# Clear stored logits before test to ensure fresh generation
self._current_logits = None
try:
# Try the standard comparison first
# This will call prepare_inputs_and_kwargs which will set self._current_logits
return super().run_test(device, test_case, config)
except AssertionError:
except AssertionError as original_error:
# If standard comparison fails, check if this is a valid case where
# indices differ but logits values are equal
# Only handle if we have stored logits (from prepare_inputs_and_kwargs)
if self._current_logits is None:
raise
logits_tensor = self._current_logits
# Re-run operations with the same logits to get results for comparison
# prepare_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.utils import (
infinicore_tensor_from_torch,
convert_infinicore_to_torch,
)
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
# Prepare infinicore inputs
infini_inputs = []
for inp in inputs:
......@@ -235,37 +239,37 @@ class OpTest(BaseOperatorTest):
infini_inputs.append(infini_tensor)
else:
infini_inputs.append(inp)
infini_kwargs = kwargs.copy()
if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
if "out" in infini_kwargs and isinstance(
infini_kwargs["out"], torch.Tensor
):
cloned_out = infini_kwargs["out"].clone().detach()
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
# Run both operators
torch_result = self.torch_operator(*inputs, **kwargs)
infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
# Extract indices from results
comparison_target = test_case.comparison_target
if comparison_target == "out":
# Compare output tensor from kwargs
ref_idx = kwargs["out"].item()
torch_result_from_infini = convert_infinicore_to_torch(
infini_kwargs["out"], kwargs["out"]
infini_kwargs["out"]
)
ic_idx = torch_result_from_infini.item()
else:
# Compare return values
ref_idx = torch_result.item()
torch_result_from_infini = convert_infinicore_to_torch(
infini_result, torch_result
)
torch_result_from_infini = convert_infinicore_to_torch(infini_result)
ic_idx = torch_result_from_infini.item()
# Check if indices are equal (standard case)
if ic_idx == ref_idx:
return
return True, "passed"
# Special case: indices differ but logits values are equal
# This is valid for random_sample when multiple indices have the same logits value
try:
......@@ -273,13 +277,13 @@ class OpTest(BaseOperatorTest):
logits_ic = logits_tensor[ic_idx].item()
if logits_ic == logits_ref:
# Valid: different indices but same logits value
return
return True, "passed"
except (IndexError, RuntimeError):
# If we can't access the logits, fall through to raise the original error
pass
# If we get here, the results are truly different
raise
raise original_error
def main():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment