Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
39ec8f0e
Unverified
Commit
39ec8f0e
authored
Nov 07, 2025
by
PanZezhong1725
Committed by
GitHub
Nov 07, 2025
Browse files
Merge pull request #558 from InfiniTensor/issue/556
parents
2e5b2342
6b8949ce
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
302 additions
and
201 deletions
+302
-201
test/infinicore/ops/rms_norm.py
test/infinicore/ops/rms_norm.py
+113
-83
test/infinicore/ops/silu.py
test/infinicore/ops/silu.py
+86
-55
test/infinicore/ops/swiglu.py
test/infinicore/ops/swiglu.py
+103
-63
No files found.
test/infinicore/ops/rms_norm.py
View file @
39ec8f0e
...
...
@@ -7,119 +7,149 @@ import torch
import
infinicore
from
framework.base
import
BaseOperatorTest
,
TensorSpec
,
TestCase
from
framework.runner
import
GenericTestRunner
from
framework.utils
import
is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, y_shape, x_shape, w_shape, y_strides, x_strides)
# Test cases format: (y_shape, x_shape, w_shape, y_strides, x_strides)
_TEST_CASES_DATA
=
[
(
TestCase
.
BOTH
,
(
1
,
4
),
(
1
,
4
),
(
4
,),
None
,
None
),
(
TestCase
.
BOTH
,
(
2
,
4
),
(
2
,
4
),
(
4
,),
None
,
None
),
(
TestCase
.
BOTH
,
(
2
,
2
,
4
),
(
2
,
2
,
4
),
(
4
,),
None
,
None
),
(
TestCase
.
BOTH
,
(
2
,
2
,
4
),
(
2
,
2
,
4
),
(
4
,),
(
12
,
8
,
1
),
(
12
,
8
,
1
)),
(
TestCase
.
BOTH
,
(
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
),
(
TestCase
.
BOTH
,
(
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
)),
# Basic cases
((
1
,
4
),
(
1
,
4
),
(
4
,),
None
,
None
),
((
2
,
4
),
(
2
,
4
),
(
4
,),
None
,
None
),
((
2
,
2
,
4
),
(
2
,
2
,
4
),
(
4
,),
None
,
None
),
# Strided cases
((
2
,
2
,
4
),
(
2
,
2
,
4
),
(
4
,),
(
12
,
8
,
1
),
(
12
,
8
,
1
)),
# Large tensors
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
None
,
None
),
((
16
,
2048
),
(
16
,
2048
),
(
2048
,),
(
4096
,
1
),
(
4096
,
1
)),
]
def parse_test_cases(data):
    """
    Build a single RMSNorm TestCase from one raw tuple laid out as
    (operation_mode, y_shape, x_shape, w_shape, y_strides, x_strides).

    The two trailing stride entries are optional. A missing or None stride
    selects TensorSpec.from_tensor; otherwise
    TensorSpec.from_strided_tensor is used.
    """

    def _spec_for(shape, strides):
        # No strides -> plain spec, strides present -> strided spec.
        if strides is None:
            return TensorSpec.from_tensor(shape)
        return TensorSpec.from_strided_tensor(shape, strides)

    operation_mode = data[0]
    y_shape, x_shape, w_shape = data[1], data[2], data[3]
    n_fields = len(data)
    y_strides = None if n_fields <= 4 else data[4]
    x_strides = None if n_fields <= 5 else data[5]

    # Specs are created in the order: input x, weight, output y.
    x_spec = _spec_for(x_shape, x_strides)
    w_spec = TensorSpec.from_tensor(w_shape)  # weight never gets strides here
    y_spec = _spec_for(y_shape, y_strides)

    return TestCase(operation_mode, [x_spec, w_spec], y_spec)
# Parse test cases
_TEST_CASES
=
[
parse_test_cases
(
data
)
for
data
in
_TEST_CASES_DATA
]
# Data types for individual tensors
_INPUT_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
]
_WEIGHT_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
,
infinicore
.
float32
]
# Generate all dtype combinations
_DTYPE_COMBINATIONS
=
[]
for
input_dtype
in
_INPUT_DTYPES
:
for
weight_dtype
in
_WEIGHT_DTYPES
:
_DTYPE_COMBINATIONS
.
append
(
{
"input_0"
:
input_dtype
,
# x tensor
"input_1"
:
weight_dtype
,
# weight tensor
"output"
:
input_dtype
,
# output tensor (same as input)
}
)
# Base data types
_TENSOR_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
]
# Tolerance
# Tolerance configuration
_TOLERANCE_MAP
=
{
infinicore
.
float16
:
{
"atol"
:
2e-3
,
"rtol"
:
2e-3
},
infinicore
.
bfloat16
:
{
"atol"
:
1e-2
,
"rtol"
:
1e-2
},
infinicore
.
float32
:
{
"atol"
:
1e-5
,
"rtol"
:
1e-4
},
}
# Data types for individual tensors
_INPUT_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
]
_WEIGHT_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
,
infinicore
.
float32
]
# EPSILON constant for RMSNorm
_EPSILON
=
1e-5
def parse_test_cases():
    """
    Expand _TEST_CASES_DATA into a flat list of RMSNorm TestCase objects.

    Each raw entry has the format
    (y_shape, x_shape, w_shape, y_strides, x_strides); the two stride
    entries are optional and default to None.

    For every entry and every (input dtype, weight dtype) pair drawn from
    _INPUT_DTYPES x _WEIGHT_DTYPES, up to three cases are generated:
      1. out-of-place (result is the return value),
      2. explicit output tensor, only if is_broadcast(y_strides) is false,
      3. in-place on the first input, only if is_broadcast(x_strides) is false.

    Returns:
        list of TestCase objects, in generation order.
    """
    # Tolerance used when the input dtype has no entry in _TOLERANCE_MAP.
    fallback_tolerance = {"atol": 1e-5, "rtol": 1e-4}

    test_cases = []

    for data in _TEST_CASES_DATA:
        y_shape = data[0]  # Output shape
        x_shape = data[1]  # Input shape
        w_shape = data[2]  # Weight shape (1D)
        y_strides = data[3] if len(data) > 3 else None
        x_strides = data[4] if len(data) > 4 else None

        # Check if tensors support in-place operations
        x_supports_inplace = not is_broadcast(x_strides)
        y_supports_inplace = not is_broadcast(y_strides)

        # Generate test cases for all dtype combinations
        for input_dtype in _INPUT_DTYPES:
            # The tolerance follows the input (== output) dtype only, so it
            # is looked up once per input dtype rather than per weight dtype.
            tolerance = _TOLERANCE_MAP.get(input_dtype, fallback_tolerance)

            for weight_dtype in _WEIGHT_DTYPES:
                # Create typed tensor specs (fresh objects per combination)
                x_spec = TensorSpec.from_tensor(x_shape, x_strides, input_dtype)
                # Weight is always created without strides
                w_spec = TensorSpec.from_tensor(w_shape, None, weight_dtype)
                y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)

                # Test Case 1: Out-of-place (return value)
                test_cases.append(
                    TestCase(
                        inputs=[x_spec, w_spec],
                        kwargs={"epsilon": _EPSILON},
                        output_spec=None,
                        comparison_target=None,
                        tolerance=tolerance,
                        description="RMSNorm - OUT_OF_PLACE",
                    )
                )

                # Test Case 2: In-place with explicit output tensor
                # (rms_norm(x, w, out=y))
                if y_supports_inplace:
                    test_cases.append(
                        TestCase(
                            inputs=[x_spec, w_spec],
                            kwargs={"epsilon": _EPSILON},
                            output_spec=y_spec,  # Specify the output tensor spec
                            comparison_target="out",
                            tolerance=tolerance,
                            description="RMSNorm - INPLACE(out)",
                        )
                    )

                # Test Case 3: In-place on input tensor (rms_norm(x, w, out=x))
                if x_supports_inplace:
                    test_cases.append(
                        TestCase(
                            inputs=[x_spec, w_spec],
                            # "out": 0 refers to the first input by index
                            kwargs={"out": 0, "epsilon": _EPSILON},
                            output_spec=None,
                            comparison_target=0,  # Compare first input
                            tolerance=tolerance,
                            description="RMSNorm - INPLACE(x)",
                        )
                    )

    return test_cases
class
OpTest
(
BaseOperatorTest
):
"""RMSNorm test with simplified
test case parsing
"""
"""RMSNorm
operator
test with simplified
implementation
"""
def
__init__
(
self
):
super
().
__init__
(
"RMS
_
Norm"
)
super
().
__init__
(
"RMSNorm"
)
def
get_test_cases
(
self
):
return
_TEST_CASES
return
parse_test_cases
()
def
get_tensor_dtypes
(
self
):
return
_TENSOR_DTYPES
def
torch_operator
(
self
,
x
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
**
kwargs
):
"""PyTorch RMSNorm implementation"""
input_dtype
=
x
.
dtype
def
get_tolerance_map
(
self
):
return
_TOLERANCE_MAP
# Convert to float32 for numerical stability
hidden_states
=
x
.
to
(
torch
.
float32
)
weight_fp32
=
weight
.
to
(
torch
.
float32
)
def
get_dtype_combinations
(
self
):
return
_DTYPE_COMBINATIONS
# Calculate RMSNorm: x * weight / sqrt(mean(x^2) + epsilon)
variance
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
)
result
=
hidden_states
*
torch
.
rsqrt
(
variance
+
epsilon
)
*
weight_fp32
def
torch_operator
(
self
,
x
,
weight
,
out
=
None
,
**
kwargs
):
input_dtype
=
x
.
dtype
hidden_states
=
x
.
to
(
torch
.
float32
)
scale
=
hidden_states
.
pow
(
2
).
mean
(
-
1
,
keepdim
=
True
).
add_
(
_EPSILON
).
rsqrt_
()
result
=
(
hidden_states
*
scale
*
weight
).
to
(
input_dtype
)
# Convert back to original dtype
result
=
result
.
to
(
input_dtype
)
if
out
is
not
None
:
out
.
set
_
(
result
)
out
.
copy
_
(
result
)
return
out
else
:
return
result
return
result
def
infinicore_operator
(
self
,
x
,
weight
,
out
=
None
,
**
kwargs
):
return
infinicore
.
rms_norm
(
x
,
weight
,
_EPSILON
,
out
=
out
)
def
infinicore_operator
(
self
,
x
,
weight
,
epsilon
=
_EPSILON
,
out
=
None
,
**
kwargs
):
"""InfiniCore RMSNorm implementation"""
return
infinicore
.
rms_norm
(
x
,
weight
,
epsilon
,
out
=
out
)
def
main
():
...
...
test/infinicore/ops/silu.py
View file @
39ec8f0e
...
...
@@ -7,98 +7,129 @@ import torch
import
infinicore
from
framework.base
import
BaseOperatorTest
,
TensorSpec
,
TestCase
from
framework.runner
import
GenericTestRunner
from
framework.utils
import
is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, shape, input_strides, output_strides)
# Test cases format: (shape, input_strides, output_strides)
# SiLU is a single-input activation function: output = input * sigmoid(input)
_TEST_CASES_DATA
=
[
# Basic 2D SiLU
(
TestCase
.
BOTH
,
(
2
,
4
),
None
,
None
),
(
TestCase
.
BOTH
,
(
128
,
64
),
None
,
None
),
((
2
,
4
),
None
,
None
),
((
128
,
64
),
None
,
None
),
# 3D SiLU
(
TestCase
.
BOTH
,
(
2
,
4
,
8
),
None
,
None
),
(
TestCase
.
BOTH
,
(
4
,
48
,
6
),
None
,
None
),
((
2
,
4
,
8
),
None
,
None
),
((
4
,
48
,
6
),
None
,
None
),
# Strided tensors
(
TestCase
.
BOTH
,
(
1
,
2048
),
(
4096
,
1
),
(
4096
,
1
)),
(
TestCase
.
BOTH
,
(
6
,
2560
),
(
2048
,
1
),
(
2560
,
1
)),
((
1
,
2048
),
(
4096
,
1
),
(
4096
,
1
)),
((
6
,
2560
),
(
2048
,
1
),
(
2560
,
1
)),
# Mixed cases
(
TestCase
.
BOTH
,
(
8
,
16
,
32
),
None
,
None
),
((
8
,
16
,
32
),
None
,
None
),
# Large tensors
(
TestCase
.
BOTH
,
(
16
,
5632
),
None
,
None
),
(
TestCase
.
BOTH
,
(
4
,
4
,
5632
),
None
,
None
),
((
16
,
5632
),
None
,
None
),
((
4
,
4
,
5632
),
None
,
None
),
]
def parse_test_cases(data):
    """
    Build a single SiLU TestCase from one raw tuple laid out as
    (operation_mode, shape, input_strides, output_strides).

    Both stride entries are optional. A missing or None stride selects
    TensorSpec.from_tensor; otherwise TensorSpec.from_strided_tensor is
    used. Input and output share the same shape.
    """

    def _spec_for(strides):
        # No strides -> plain spec, strides present -> strided spec.
        if strides is None:
            return TensorSpec.from_tensor(shape)
        return TensorSpec.from_strided_tensor(shape, strides)

    operation_mode = data[0]
    shape = data[1]
    n_fields = len(data)
    input_strides = None if n_fields <= 2 else data[2]
    output_strides = None if n_fields <= 3 else data[3]

    # The input spec is created before the output spec.
    input_spec = _spec_for(input_strides)
    output_spec = _spec_for(output_strides)

    return TestCase(operation_mode, [input_spec], output_spec)
# Parse test cases
_TEST_CASES
=
[
parse_test_cases
(
data
)
for
data
in
_TEST_CASES_DATA
]
# Data types
_TENSOR_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
,
infinicore
.
float32
]
# Tolerance
# Tolerance configuration
_TOLERANCE_MAP
=
{
infinicore
.
float16
:
{
"atol"
:
1e-3
,
"rtol"
:
1e-3
},
infinicore
.
float32
:
{
"atol"
:
1e-5
,
"rtol"
:
1e-5
},
infinicore
.
bfloat16
:
{
"atol"
:
5e-3
,
"rtol"
:
1e-2
},
}
# Data types to test
_TENSOR_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
,
infinicore
.
float32
]
def parse_test_cases():
    """
    Expand _TEST_CASES_DATA into a flat list of SiLU TestCase objects.

    Each raw entry has the format (shape, input_strides, output_strides);
    both stride entries are optional and default to None.

    For every entry and every dtype in _TENSOR_DTYPES, up to three cases
    are generated:
      1. out-of-place (result is the return value),
      2. explicit output tensor, only if is_broadcast(output_strides) is false,
      3. in-place on the input, only if is_broadcast(input_strides) is false.

    Returns:
        list of TestCase objects, in generation order.
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        shape = data[0]
        input_strides = data[1] if len(data) > 1 else None
        output_strides = data[2] if len(data) > 2 else None

        # Check if tensors support in-place operations
        input_supports_inplace = not is_broadcast(input_strides)
        output_supports_inplace = not is_broadcast(output_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            # Fall back to a fixed tolerance when the dtype has no map entry.
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})

            # Create typed tensor specs
            input_spec = TensorSpec.from_tensor(shape, input_strides, dtype)
            output_spec = TensorSpec.from_tensor(shape, output_strides, dtype)

            # Test Case 1: Out-of-place (return value)
            test_cases.append(
                TestCase(
                    inputs=[input_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="SiLU - OUT_OF_PLACE",
                )
            )

            # Test Case 2: In-place with explicit output tensor
            # (silu(input, out=output))
            if output_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec],
                        # NOTE(review): None here vs. {} in case 1 -- presumably
                        # TestCase treats both as "no kwargs"; confirm.
                        kwargs=None,
                        output_spec=output_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description="SiLU - INPLACE(out)",
                    )
                )

            # Test Case 3: In-place on first input (silu(input, out=input))
            if input_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec],
                        kwargs={"out": 0},  # Use index 0 for first input
                        output_spec=None,
                        comparison_target=0,  # Compare first input
                        tolerance=tolerance,
                        description="SiLU - INPLACE(input)",
                    )
                )

    return test_cases
class
OpTest
(
BaseOperatorTest
):
"""SiLU test with simplified
test case parsing
"""
"""SiLU
operator
test with simplified
implementation
"""
def
__init__
(
self
):
super
().
__init__
(
"SiLU"
)
def
get_test_cases
(
self
):
return
_TEST_CASES
def
get_tensor_dtypes
(
self
):
return
_TENSOR_DTYPES
def
get_tolerance_map
(
self
):
return
_TOLERANCE_MAP
return
parse_test_cases
()
def
torch_operator
(
self
,
input
,
out
=
None
,
**
kwargs
):
#
SiLU implementation: input * sigmoid(input)
"""PyTorch
SiLU implementation: input * sigmoid(input)
"""
sigmoid_input
=
torch
.
sigmoid
(
input
)
result
=
input
*
sigmoid_input
if
out
is
not
None
:
out
.
copy_
(
result
)
return
out
return
result
def
infinicore_operator
(
self
,
input
,
out
=
None
,
**
kwargs
):
"""InfiniCore SiLU implementation"""
return
infinicore
.
silu
(
input
,
out
=
out
)
...
...
test/infinicore/ops/swiglu.py
View file @
39ec8f0e
...
...
@@ -7,105 +7,145 @@ import torch
import
infinicore
from
framework.base
import
BaseOperatorTest
,
TensorSpec
,
TestCase
from
framework.runner
import
GenericTestRunner
from
framework.utils
import
is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides)
# SwiGLU operates element-wise on two tensors of the same shape
# Test cases format: (shape, a_strides, b_strides, c_strides)
# SwiGLU operates element-wise on two tensors of the same shape: output = a * b * sigmoid(b)
_TEST_CASES_DATA
=
[
# Basic 2D SwiGLU
(
TestCase
.
BOTH
,
(
2
,
4
),
None
,
None
,
None
),
(
TestCase
.
BOTH
,
(
128
,
64
),
None
,
None
,
None
),
((
2
,
4
),
None
,
None
,
None
),
((
128
,
64
),
None
,
None
,
None
),
# 3D SwiGLU
(
TestCase
.
BOTH
,
(
2
,
4
,
8
),
None
,
None
,
None
),
(
TestCase
.
BOTH
,
(
4
,
48
,
6
),
None
,
None
,
None
),
((
2
,
4
,
8
),
None
,
None
,
None
),
((
4
,
48
,
6
),
None
,
None
,
None
),
# Strided tensors
(
TestCase
.
BOTH
,
(
1
,
2048
),
(
4096
,
1
),
(
4096
,
1
),
(
4096
,
1
)),
(
TestCase
.
BOTH
,
(
6
,
2560
),
(
2048
,
1
),
(
1
,
2048
),
(
2560
,
1
)),
((
1
,
2048
),
(
4096
,
1
),
(
4096
,
1
),
(
4096
,
1
)),
((
6
,
2560
),
(
2048
,
1
),
(
1
,
2048
),
(
2560
,
1
)),
# Mixed cases
(
TestCase
.
BOTH
,
(
8
,
16
,
32
),
None
,
None
,
None
),
((
8
,
16
,
32
),
None
,
None
,
None
),
# Large tensors
(
TestCase
.
BOTH
,
(
16
,
5632
),
None
,
None
,
None
),
(
TestCase
.
BOTH
,
(
4
,
4
,
5632
),
None
,
None
,
None
),
((
16
,
5632
),
None
,
None
,
None
),
((
4
,
4
,
5632
),
None
,
None
,
None
),
]
def parse_test_cases(data):
    """
    Build a single SwiGLU TestCase from one raw tuple laid out as
    (operation_mode, shape, a_strides, b_strides, c_strides).

    All three stride entries are optional. A missing or None stride selects
    TensorSpec.from_tensor; otherwise TensorSpec.from_strided_tensor is
    used. Both inputs and the output share the same shape.
    """

    def _spec_for(strides):
        # No strides -> plain spec, strides present -> strided spec.
        if strides is None:
            return TensorSpec.from_tensor(shape)
        return TensorSpec.from_strided_tensor(shape, strides)

    operation_mode = data[0]
    shape = data[1]
    n_fields = len(data)
    a_strides = None if n_fields <= 2 else data[2]
    b_strides = None if n_fields <= 3 else data[3]
    c_strides = None if n_fields <= 4 else data[4]

    # Specs are created in the order: a, b, then the output c.
    a_spec = _spec_for(a_strides)
    b_spec = _spec_for(b_strides)
    c_spec = _spec_for(c_strides)

    return TestCase(operation_mode, [a_spec, b_spec], c_spec)
# Parse test cases
_TEST_CASES
=
[
parse_test_cases
(
data
)
for
data
in
_TEST_CASES_DATA
]
# Data types
_TENSOR_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
,
infinicore
.
float32
]
# Tolerance
# Tolerance configuration
_TOLERANCE_MAP
=
{
infinicore
.
float16
:
{
"atol"
:
1e-3
,
"rtol"
:
1e-3
},
infinicore
.
float32
:
{
"atol"
:
1e-5
,
"rtol"
:
1e-5
},
infinicore
.
bfloat16
:
{
"atol"
:
5e-3
,
"rtol"
:
1e-2
},
}
# Data types to test
_TENSOR_DTYPES
=
[
infinicore
.
float16
,
infinicore
.
bfloat16
,
infinicore
.
float32
]
def parse_test_cases():
    """
    Expand _TEST_CASES_DATA into a flat list of SwiGLU TestCase objects.

    Each raw entry has the format (shape, a_strides, b_strides, c_strides);
    the stride entries are optional and default to None.
    SwiGLU is a two-input operation: output = a * b * sigmoid(b).

    For every entry and every dtype in _TENSOR_DTYPES, up to four cases
    are generated:
      1. out-of-place (result is the return value),
      2. explicit output tensor, only if is_broadcast(c_strides) is false,
      3. in-place on a, only if is_broadcast(a_strides) is false and
         a_strides == b_strides,
      4. in-place on b, only if is_broadcast(b_strides) is false and
         a_strides == b_strides.

    Returns:
        list of TestCase objects, in generation order.
    """
    test_cases = []

    for data in _TEST_CASES_DATA:
        shape = data[0]
        a_strides = data[1] if len(data) > 1 else None
        b_strides = data[2] if len(data) > 2 else None
        c_strides = data[3] if len(data) > 3 else None

        # Check if tensors support in-place operations.
        # Writing into an input additionally requires both inputs to have
        # equal strides.
        a_supports_inplace = not is_broadcast(a_strides) and a_strides == b_strides
        b_supports_inplace = not is_broadcast(b_strides) and a_strides == b_strides
        c_supports_inplace = not is_broadcast(c_strides)

        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            # Fall back to a fixed tolerance when the dtype has no map entry.
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})

            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
            b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
            c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)

            # Test Case 1: Out-of-place (return value)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, b_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="SwiGLU - OUT_OF_PLACE",
                )
            )

            # Test Case 2: In-place with explicit output tensor
            # (swiglu(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        # NOTE(review): None here vs. {} in case 1 -- presumably
                        # TestCase treats both as "no kwargs"; confirm.
                        kwargs=None,
                        output_spec=c_spec,  # Specify the output tensor spec
                        comparison_target="out",
                        tolerance=tolerance,
                        description="SwiGLU - INPLACE(out)",
                    )
                )

            # Test Case 3: In-place on first input (swiglu(a, b, out=a))
            if a_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 0},  # Use index 0 for first input
                        output_spec=None,
                        comparison_target=0,  # Compare first input
                        tolerance=tolerance,
                        description="SwiGLU - INPLACE(a)",
                    )
                )

            # Test Case 4: In-place on second input (swiglu(a, b, out=b))
            if b_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 1},  # Use index 1 for second input
                        output_spec=None,
                        comparison_target=1,  # Compare second input
                        tolerance=tolerance,
                        description="SwiGLU - INPLACE(b)",
                    )
                )

    return test_cases
class
OpTest
(
BaseOperatorTest
):
"""SwiGLU test with simplified
test case parsing
"""
"""SwiGLU
operator
test with simplified
implementation
"""
def
__init__
(
self
):
super
().
__init__
(
"SwiGLU"
)
def
get_test_cases
(
self
):
return
_TEST_CASES
def
get_tensor_dtypes
(
self
):
return
_TENSOR_DTYPES
def
get_tolerance_map
(
self
):
return
_TOLERANCE_MAP
return
parse_test_cases
()
def
torch_operator
(
self
,
a
,
b
,
out
=
None
,
**
kwargs
):
#
SwiGLU implementation: a * b * sigmoid(b)
"""PyTorch
SwiGLU implementation: a * b * sigmoid(b)
"""
sigmoid_b
=
torch
.
sigmoid
(
b
)
result
=
a
*
b
*
sigmoid_b
if
out
is
not
None
:
out
.
copy_
(
result
)
return
out
return
result
def
infinicore_operator
(
self
,
a
,
b
,
out
=
None
,
**
kwargs
):
"""InfiniCore SwiGLU implementation"""
return
infinicore
.
swiglu
(
a
,
b
,
out
=
out
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment