Commit 61a7dc0e authored by wooway777's avatar wooway777
Browse files

issue/666 - Standardized test imports

parent 406c9668
......@@ -5,9 +5,8 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
from framework.tensor import TensorInitializer
from framework.runner import GenericTestRunner
# Test cases format: (indices_shape, indices_strides_or_None, num_classes_or_None)
_TEST_CASES_DATA = [
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (shape, pad_tuple, mode, value_or_None, input_strides_or_None)
# infinicore.nn.functional.pad(input, pad, mode='constant', value=0)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (shape, p, eps, keepdim, a_strides_or_None, b_strides_or_None)
# infinicore.nn.functional.pairwise_distance(x1, x2, p=2.0, eps=1e-6, keepdim=False)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (shape, p, a_strides_or_None)
# infinicore.pdist(input, p=2.0) computes pairwise distances between rows of input
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (shape, upscale_factor, input_strides_or_None)
# infinicore.nn.functional.pixel_shuffle(input, upscale_factor)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (shape, downscale_factor, input_strides_or_None)
# infinicore.nn.functional.pixel_unshuffle(input, downscale_factor)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (in_shape, in_strides_or_None, weight_shape_or_None)
# Note: PReLU requires a weight parameter of shape (C,) or (1,), we create a per-channel weight when possible.
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (in_shape, in_strides_or_None, dim_or_None, keepdim_or_None, dtype_or_None)
# prod computes product along dim(s) or overall
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (in_shape, in_strides_or_None, q_or_None, dim_or_None, keepdim_or_None, out_strides_or_None)
# quantile computes quantiles along dim or overall. q may be float or tensor
......
......@@ -6,8 +6,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import infinicore.nn.functional as F
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
from framework.tensor import TensorInitializer
# ==============================================================================
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (in_shape, in_strides_or_None)
# infinicore.reciprocal(input)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (in_shape, in_strides_or_None)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (a_shape, a_strides_or_None, b_shape_or_None)
# infinicore.remainder(a, b)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (in_shape, in_strides_or_None, new_shape)
# reshape can change shape; out parameter is not used in infinicore.reshape API (returns view or tensor)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# ==============================================================================
# Operator-specific configuration
......@@ -150,7 +154,7 @@ class OpTest(BaseOperatorTest):
def infinicore_operator(self, x, weight, epsilon=_EPSILON, out=None, **kwargs):
"""InfiniCore RMSNorm implementation"""
import infinicore.nn.functional as F
return F.rms_norm(x, weight.shape, weight, epsilon, out=out)
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
from infinicore.nn.functional import RopeAlgo
import infinicore
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (shape, k, dims_tuple, input_strides_or_None)
# infinicore.rot90(input, k=1, dims=(0,1))
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# round(input, decimals=0)
# We'll test with various decimals including negative values and None.
......
......@@ -5,9 +5,13 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
from framework import (
BaseOperatorTest,
TensorSpec,
TestCase,
GenericTestRunner,
is_broadcast,
)
# Test cases format: (in_shape, in_strides_or_None, lower_or_None, upper_or_None)
......
......@@ -5,8 +5,7 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (q_shape, k_shape, v_shape, attn_mask_or_None, dropout_p, is_causal)
# q/k/v typically have shape (..., seq_len, head_dim) or (batch, seq_len, num_heads, head_dim)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment