Unverified commit 5028ea42, authored by pengcheng888, committed by GitHub

Merge pull request #597 from pengcheng888/issue/596

issue/596 - Split the functions in functional.py into per-function modules under a functional/ folder

parents 1a618ff0 3e8c6df1
infinicore/nn/__init__.py
-from infinicore.nn import (
-    functional as functional,
-)
+from infinicore.nn import functional
+
+__all__ = ["functional"]
infinicore/nn/functional.py (removed; split into the per-function modules below)

import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu"]


def causal_softmax(input: Tensor, out=None) -> Tensor:
    r"""Apply a causal softmax function."""
    if out is None:
        return Tensor(_infinicore.causal_softmax(input._underlying))

    _infinicore.causal_softmax_(out._underlying, input._underlying)
    return out


def rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Tensor,
    eps: float = 1e-5,
    *,
    out=None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization."""
    assert normalized_shape == weight.shape, (
        "normalized_shape does not match weight.shape."
    )

    if out is None:
        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))

    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
    return out


def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
        return infinicore.ntops.torch.silu(input, inplace=inplace)

    if inplace:
        _infinicore.silu_(input._underlying, input._underlying)
        return input

    if out is None:
        return Tensor(_infinicore.silu(input._underlying))

    _infinicore.silu_(out._underlying, input._underlying)
    return out


def swiglu(input: Tensor, other: Tensor, *, out=None):
    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
    if out is None:
        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))

    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
    return out


def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from logits with nucleus/top-k filtering."""
    if out is None:
        return Tensor(
            _infinicore.random_sample(
                logits._underlying,
                random_val,
                topp,
                topk,
                temperature,
            )
        )

    _infinicore.random_sample_(
        out._underlying,
        logits._underlying,
        random_val,
        topp,
        topk,
        temperature,
    )
    return out
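All five wrappers follow the same convention: with out=None they allocate and return a fresh Tensor via the allocating binding; otherwise they write into out through the underscore-suffixed in-place binding and return it. A minimal usage sketch of both forms (the tensors x and buf are illustrative placeholders, not from this diff):

# Hedged usage sketch; assumes `x` and `buf` are infinicore Tensors of
# compatible shape and dtype created elsewhere.
from infinicore.nn import functional as F

y = F.silu(x)        # allocating form: returns a new Tensor
F.silu(x, out=buf)   # out= form: writes into `buf` and returns it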
infinicore/nn/functional/__init__.py (new)

from .causal_softmax import causal_softmax
from .random_sample import random_sample
from .rms_norm import rms_norm
from .silu import silu
from .swiglu import swiglu

__all__ = [
    "causal_softmax",
    "random_sample",
    "rms_norm",
    "silu",
    "swiglu",
]
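Because the package __init__.py re-exports every function and keeps the same names in __all__, the split is transparent to callers; imports that resolved against the old single-module functional.py resolve identically against the new package:

# Both import styles keep working after the split:
from infinicore.nn import functional
from infinicore.nn.functional import rms_norm  # now served by rms_norm.py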
infinicore/nn/functional/causal_softmax.py (new)

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["causal_softmax"]


def causal_softmax(input: Tensor, out=None) -> Tensor:
    r"""Apply a causal softmax function."""
    if out is None:
        return Tensor(_infinicore.causal_softmax(input._underlying))

    _infinicore.causal_softmax_(out._underlying, input._underlying)
    return out
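For intuition, a causal softmax is an ordinary softmax applied after masking out "future" positions above the diagonal. A PyTorch reference sketch over the trailing two dimensions (for comparison only; the _infinicore kernel is authoritative and its exact masking convention is an assumption here):

import torch

def causal_softmax_ref(x: torch.Tensor) -> torch.Tensor:
    # Mask entries above the diagonal of the trailing (rows, cols) block,
    # then softmax over the last dimension.
    rows, cols = x.shape[-2], x.shape[-1]
    mask = torch.ones(rows, cols, dtype=torch.bool, device=x.device)
    mask = torch.triu(mask, diagonal=cols - rows + 1)  # "future" positions
    return torch.softmax(x.masked_fill(mask, float("-inf")), dim=-1)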
infinicore/nn/functional/random_sample.py (new)

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["random_sample"]


def random_sample(
    logits: Tensor,
    random_val: float,
    topp: float,
    topk: int,
    temperature: float,
    *,
    out=None,
) -> Tensor:
    r"""Sample an index from logits with nucleus/top-k filtering."""
    if out is None:
        return Tensor(
            _infinicore.random_sample(
                logits._underlying,
                random_val,
                topp,
                topk,
                temperature,
            )
        )

    _infinicore.random_sample_(
        out._underlying,
        logits._underlying,
        random_val,
        topp,
        topk,
        temperature,
    )
    return out
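The parameters mirror standard nucleus sampling: logits are scaled by temperature, candidates are truncated to the topk highest and to the smallest prefix whose probability mass reaches topp, and random_val picks within the kept mass. A hedged PyTorch sketch of that pipeline, assuming 1-D logits and random_val in [0, 1) (details are illustrative; the test's torch_random_sample is the reference actually used):

import torch

def random_sample_ref(logits, random_val, topp, topk, temperature):
    # Assumed semantics: temperature scaling, then top-k / top-p truncation,
    # then `random_val` selects within the renormalized kept mass.
    probs = torch.softmax(logits.float() / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum = torch.cumsum(sorted_probs, dim=0)
    keep = min(topk, int(torch.searchsorted(cum, torch.tensor(topp)).item()) + 1)
    kept = sorted_probs[:keep] / sorted_probs[:keep].sum()  # renormalize
    pick = torch.searchsorted(torch.cumsum(kept, dim=0), torch.tensor(random_val))
    return sorted_idx[min(int(pick.item()), keep - 1)]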
infinicore/nn/functional/rms_norm.py (new)

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["rms_norm"]


def rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Tensor,
    eps: float = 1e-5,
    *,
    out=None,
) -> Tensor:
    r"""Apply Root Mean Square Layer Normalization."""
    assert normalized_shape == weight.shape, (
        "normalized_shape does not match weight.shape."
    )

    if out is None:
        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))

    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
    return out
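Numerically, RMSNorm divides by the root mean square over the normalized dimensions and rescales by weight, with no mean subtraction. Assuming weight covers the last dimension, an equivalent PyTorch sketch:

import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5):
    # x / sqrt(mean(x^2) + eps) * weight, reduced over the last dimension
    # (the simplest case; weight may in general cover several trailing dims).
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight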
infinicore/nn/functional/silu.py (new)

import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["silu"]


def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
        return infinicore.ntops.torch.silu(input, inplace=inplace)

    if inplace:
        _infinicore.silu_(input._underlying, input._underlying)
        return input

    if out is None:
        return Tensor(_infinicore.silu(input._underlying))

    _infinicore.silu_(out._underlying, input._underlying)
    return out
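SiLU itself is just x * sigmoid(x); the branches above only decide where the result lands (the ntops fast path on cuda/musa devices, an in-place update, a caller-provided out, or a fresh tensor). Equivalently, in PyTorch:

import torch

def silu_ref(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)  # same as torch.nn.functional.silu(x)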
infinicore/nn/functional/swiglu.py (new)

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["swiglu"]


def swiglu(input: Tensor, other: Tensor, *, out=None):
    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
    if out is None:
        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))

    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
    return out
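SwiGLU combines a SiLU-gated branch with a linear branch element-wise. The argument convention is not spelled out in this diff, so the sketch below assumes `other` is the gate and `input` the linear branch; check the _infinicore.swiglu kernel before relying on this ordering:

import torch

def swiglu_ref(input: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # Assumed convention: out = input * silu(other); the kernel may swap roles.
    return input * (other * torch.sigmoid(other))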
@@ -109,7 +109,11 @@ def torch_random_sample(data, random_val, topp, topk, voc, temperature):
         try:
             idx = torch.searchsorted(cum_probs, threshold)
         except Exception:
             indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
-            idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            idx = (
+                indices[0]
+                if indices.numel() > 0
+                else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
+            )
         return sorted_indices[idx]
     return torch.argmax(data)
@@ -191,41 +195,41 @@ class OpTest(BaseOperatorTest):
     def run_test(self, device, test_case, config):
         """
         Override run_test to handle random_sample's special comparison logic.

         For random_sample, if the indices differ but the logits values at those
         indices are equal, the result is still considered valid. This handles
         cases where multiple valid indices exist due to floating-point precision.

         This is necessary because random_sample can return different valid indices
         when multiple positions have the same logits value, especially with
         low-precision types like bfloat16 due to floating-point rounding.
         """
         # Clear stored logits before test to ensure fresh generation
         self._current_logits = None

         try:
             # Try the standard comparison first
             # This will call prepare_inputs_and_kwargs which will set self._current_logits
             return super().run_test(device, test_case, config)
-        except AssertionError:
+        except AssertionError as original_error:
             # If standard comparison fails, check if this is a valid case where
             # indices differ but logits values are equal
             # Only handle if we have stored logits (from prepare_inputs_and_kwargs)
             if self._current_logits is None:
                 raise

             logits_tensor = self._current_logits

             # Re-run operations with the same logits to get results for comparison
             # prepare_inputs_and_kwargs will reuse self._current_logits if it exists
             from framework.utils import (
                 infinicore_tensor_from_torch,
                 convert_infinicore_to_torch,
             )

             inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)

             # Prepare infinicore inputs
             infini_inputs = []
             for inp in inputs:
@@ -235,37 +239,37 @@ class OpTest(BaseOperatorTest):
                     infini_inputs.append(infini_tensor)
                 else:
                     infini_inputs.append(inp)

             infini_kwargs = kwargs.copy()
-            if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
+            if "out" in infini_kwargs and isinstance(
+                infini_kwargs["out"], torch.Tensor
+            ):
                 cloned_out = infini_kwargs["out"].clone().detach()
                 infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)

             # Run both operators
             torch_result = self.torch_operator(*inputs, **kwargs)
             infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)

             # Extract indices from results
             comparison_target = test_case.comparison_target
             if comparison_target == "out":
                 # Compare output tensor from kwargs
                 ref_idx = kwargs["out"].item()
                 torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_kwargs["out"], kwargs["out"]
+                    infini_kwargs["out"]
                 )
                 ic_idx = torch_result_from_infini.item()
             else:
                 # Compare return values
                 ref_idx = torch_result.item()
-                torch_result_from_infini = convert_infinicore_to_torch(
-                    infini_result, torch_result
-                )
+                torch_result_from_infini = convert_infinicore_to_torch(infini_result)
                 ic_idx = torch_result_from_infini.item()

             # Check if indices are equal (standard case)
             if ic_idx == ref_idx:
-                return
+                return True, "passed"

             # Special case: indices differ but logits values are equal
             # This is valid for random_sample when multiple indices have the same logits value
             try:
@@ -273,13 +277,13 @@ class OpTest(BaseOperatorTest):
                 logits_ref = logits_tensor[ref_idx].item()
                 logits_ic = logits_tensor[ic_idx].item()
                 if logits_ic == logits_ref:
                     # Valid: different indices but same logits value
-                    return
+                    return True, "passed"
             except (IndexError, RuntimeError):
                 # If we can't access the logits, fall through to raise the original error
                 pass

             # If we get here, the results are truly different
-            raise
+            raise original_error

 def main():
...
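The tie-tolerant comparison in this test matters mostly for low-precision inputs: rounding to bfloat16 can collapse nearby logits into exactly equal values, after which either index is an equally valid sample. A quick illustration of such a collision (values chosen for demonstration):

import torch

x = torch.tensor([1.0001, 1.0002]).to(torch.bfloat16)
print(x[0] == x[1])  # tensor(True): distinct fp32 logits become a tie in bf16,
                     # so argmax/sampling may legitimately return either index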