Commit 3e8c6df1 authored by pengcheng888's avatar pengcheng888

issue/596 - Split the functions in functional.py into separate per-function modules under the functional folder

parent 1a618ff0
infinicore/nn/__init__.py (modified):

-from infinicore.nn import (
-    functional as functional,
-)
+from infinicore.nn import functional
+
+__all__ = ["functional"]
infinicore/nn/functional.py (deleted):

import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu"]


def causal_softmax(input: Tensor, out=None) -> Tensor:
    r"""Apply a causal softmax function."""
    if out is None:
        return Tensor(_infinicore.causal_softmax(input._underlying))

    _infinicore.causal_softmax_(out._underlying, input._underlying)
    return out


def rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Tensor,
    eps: float = 1e-5,
    *,
    out=None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization."""
    assert normalized_shape == weight.shape, (
        "normalized_shape does not match weight.shape."
    )

    if out is None:
        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))

    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
    return out


def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
        return infinicore.ntops.torch.silu(input, inplace=inplace)

    if inplace:
        _infinicore.silu_(input._underlying, input._underlying)
        return input

    if out is None:
        return Tensor(_infinicore.silu(input._underlying))

    _infinicore.silu_(out._underlying, input._underlying)
    return out


def swiglu(input: Tensor, other: Tensor, *, out=None):
    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
    if out is None:
        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))

    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
    return out


def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from logits with nucleus/top-k filtering."""
    if out is None:
        return Tensor(
            _infinicore.random_sample(
                logits._underlying,
                random_val,
                topp,
                topk,
                temperature,
            )
        )

    _infinicore.random_sample_(
        out._underlying,
        logits._underlying,
        random_val,
        topp,
        topk,
        temperature,
    )
    return out
infinicore/nn/functional/__init__.py (new):

from .causal_softmax import causal_softmax
from .random_sample import random_sample
from .rms_norm import rms_norm
from .silu import silu
from .swiglu import swiglu

__all__ = [
    "causal_softmax",
    "random_sample",
    "rms_norm",
    "silu",
    "swiglu",
]
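Note (reviewer sketch, not part of this commit): the re-exports keep the public import surface unchanged, so existing call sites do not need edits. A hypothetical call site, where x stands for a Tensor obtained from infinicore's tensor API (not shown in this commit):

from infinicore.nn import functional as F

y = F.silu(x)     # out-of-place: allocates and returns a new Tensor
F.silu(x, out=y)  # writes into a preallocated Tensor instead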
infinicore/nn/functional/causal_softmax.py (new):

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["causal_softmax"]


def causal_softmax(input: Tensor, out=None) -> Tensor:
    r"""Apply a causal softmax function."""
    if out is None:
        return Tensor(_infinicore.causal_softmax(input._underlying))

    _infinicore.causal_softmax_(out._underlying, input._underlying)
    return out
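For reference only (not part of this commit): a causal softmax masks out future positions before normalizing, as in decoder attention. A minimal PyTorch sketch of the presumed semantics, assuming the _infinicore kernel masks keys that lie beyond each query position:

import torch

def causal_softmax_ref(scores: torch.Tensor) -> torch.Tensor:
    # Mask entries above the (offset) diagonal, i.e. future key positions
    # relative to each query row, then softmax each row.
    seq_q, seq_k = scores.shape[-2:]
    future = torch.triu(
        torch.ones(seq_q, seq_k, dtype=torch.bool, device=scores.device),
        diagonal=1 + seq_k - seq_q,
    )
    return scores.masked_fill(future, float("-inf")).softmax(dim=-1)

print(causal_softmax_ref(torch.randn(4, 4)))  # rows sum to 1, upper triangle is 0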
infinicore/nn/functional/random_sample.py (new):

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["random_sample"]


def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from logits with nucleus/top-k filtering."""
    if out is None:
        return Tensor(
            _infinicore.random_sample(
                logits._underlying,
                random_val,
                topp,
                topk,
                temperature,
            )
        )

    _infinicore.random_sample_(
        out._underlying,
        logits._underlying,
        random_val,
        topp,
        topk,
        temperature,
    )
    return out
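For intuition (a sketch under assumptions, not the backend's exact algorithm): temperature-scaled top-k/top-p sampling can be written in PyTorch roughly as below. random_val is assumed to be a uniform draw in [0, 1) used as an inverse-CDF threshold, mirroring torch_random_sample in the test diff further down:

import torch

def random_sample_ref(logits, random_val, topp, topk, temperature):
    # Temperature-scaled softmax over the vocabulary.
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    # Top-k: keep only the k most probable tokens.
    k = min(topk, sorted_probs.numel())
    sorted_probs, sorted_indices = sorted_probs[:k], sorted_indices[:k]
    # Top-p (nucleus): keep the smallest prefix whose mass reaches topp.
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    cutoff = int(torch.searchsorted(cum_probs, topp)) + 1
    cum_probs = cum_probs[:cutoff]
    # Inverse-CDF draw: random_val picks a position within the kept prefix.
    threshold = random_val * float(cum_probs[-1])
    idx = int(torch.searchsorted(cum_probs, threshold))
    return sorted_indices[min(idx, len(cum_probs) - 1)]

print(random_sample_ref(torch.randn(32), 0.42, 0.9, 10, 1.0))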
infinicore/nn/functional/rms_norm.py (new):

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["rms_norm"]


def rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Tensor,
    eps: float = 1e-5,
    *,
    out=None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization."""
    assert normalized_shape == weight.shape, (
        "normalized_shape does not match weight.shape."
    )

    if out is None:
        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))

    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
    return out
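For reference (not part of this commit): RMSNorm scales each vector by the reciprocal root of its mean square, then applies the learned weight. A minimal PyTorch sketch, assuming normalization over the last dimension (consistent with the normalized_shape == weight.shape assertion for a 1-D weight):

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # y = x / sqrt(mean(x**2, last dim) + eps) * weight
    ms = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(ms + eps)).to(x.dtype) * weight

x = torch.randn(2, 8)
print(rms_norm_ref(x, torch.ones(8)).shape)  # torch.Size([2, 8])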
infinicore/nn/functional/silu.py (new):

import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["silu"]


def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
        return infinicore.ntops.torch.silu(input, inplace=inplace)

    if inplace:
        _infinicore.silu_(input._underlying, input._underlying)
        return input

    if out is None:
        return Tensor(_infinicore.silu(input._underlying))

    _infinicore.silu_(out._underlying, input._underlying)
    return out
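Numerically, SiLU is x * sigmoid(x); the branch above only decides whether the ntops fast path (cuda/musa) or the _infinicore kernel computes it, and where the result is written. A one-line PyTorch reference, for comparison only:

import torch

def silu_ref(x: torch.Tensor) -> torch.Tensor:
    # SiLU (a.k.a. Swish): x * sigmoid(x)
    return x * torch.sigmoid(x)

print(silu_ref(torch.tensor([-1.0, 0.0, 1.0])))  # ~[-0.2689, 0.0, 0.7311]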
infinicore/nn/functional/swiglu.py (new):

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["swiglu"]


def swiglu(input: Tensor, other: Tensor, *, out=None):
    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
    if out is None:
        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))

    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
    return out
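SwiGLU combines a SiLU-activated gate with an element-wise product. The operand convention of the _infinicore kernel is not visible in this commit, so the PyTorch sketch below assumes the common gate-first form silu(gate) * up:

import torch

def swiglu_ref(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # SwiGLU (assumed convention): silu(gate) * up, element-wise.
    return torch.nn.functional.silu(gate) * up

print(swiglu_ref(torch.randn(4), torch.randn(4)))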
random_sample test (modified):

@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
             idx = torch.searchsorted(cum_probs, threshold)
         except Exception:
             indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
-            idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            idx = (
+                indices[0]
+                if indices.numel() > 0
+                else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            )
         return sorted_indices[idx]

     return torch.argmax(data)

@@ -207,7 +211,7 @@ class OpTest(BaseOperatorTest):
             # Try the standard comparison first
            # This will call prepare_inputs_and_kwargs which will set self._current_logits
             return super().run_test(device, test_case, config)
-        except AssertionError:
+        except AssertionError as original_error:
             # If standard comparison fails, check if this is a valid case where
             # indices differ but logits values are equal

@@ -237,7 +241,9 @@ class OpTest(BaseOperatorTest):
                 infini_inputs.append(inp)

             infini_kwargs = kwargs.copy()
-            if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
+            if "out" in infini_kwargs and isinstance(
+                infini_kwargs["out"], torch.Tensor
+            ):
                 cloned_out = infini_kwargs["out"].clone().detach()
                 infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)

@@ -251,20 +257,18 @@ class OpTest(BaseOperatorTest):
                 # Compare output tensor from kwargs
                 ref_idx = kwargs["out"].item()
                 torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_kwargs["out"], kwargs["out"]
+                    infini_kwargs["out"]
                 )
                 ic_idx = torch_result_from_infini.item()
             else:
                 # Compare return values
                 ref_idx = torch_result.item()
-                torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_result, torch_result
-                )
+                torch_result_from_infini = convert_infinicore_to_torch(infini_result)
                 ic_idx = torch_result_from_infini.item()

             # Check if indices are equal (standard case)
             if ic_idx == ref_idx:
-                return
+                return True, "passed"

@@ -273,13 +277,13 @@ class OpTest(BaseOperatorTest):
                     logits_ic = logits_tensor[ic_idx].item()
                     if logits_ic == logits_ref:
                         # Valid: different indices but same logits value
-                        return
+                        return True, "passed"
                 except (IndexError, RuntimeError):
                     # If we can't access the logits, fall through to raise the original error
                     pass

             # If we get here, the results are truly different
-            raise
+            raise original_error

 def main():
...