Unverified Commit 39ec8f0e authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #558 from InfiniTensor/issue/556

parents 2e5b2342 6b8949ce
...@@ -7,119 +7,149 @@ import torch ...@@ -7,119 +7,149 @@ import torch
import infinicore import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ============================================================================== # ==============================================================================
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (operation_mode, y_shape, x_shape, w_shape, y_strides, x_strides) # Test cases format: (y_shape, x_shape, w_shape, y_strides, x_strides)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
(TestCase.BOTH, (1, 4), (1, 4), (4,), None, None), # Basic cases
(TestCase.BOTH, (2, 4), (2, 4), (4,), None, None), ((1, 4), (1, 4), (4,), None, None),
(TestCase.BOTH, (2, 2, 4), (2, 2, 4), (4,), None, None), ((2, 4), (2, 4), (4,), None, None),
(TestCase.BOTH, (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)), ((2, 2, 4), (2, 2, 4), (4,), None, None),
(TestCase.BOTH, (16, 2048), (16, 2048), (2048,), None, None), # Strided cases
(TestCase.BOTH, (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)), ((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
# Large tensors
((16, 2048), (16, 2048), (2048,), None, None),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
] ]
# Tolerance configuration
def parse_test_cases(data):
"""
Parse RMSNorm test case data according to format:
(operation_mode, y_shape, x_shape, w_shape, y_strides, x_strides)
"""
operation_mode = data[0]
y_shape = data[1] # Output shape
x_shape = data[2] # Input shape
w_shape = data[3] # Weight shape (1D)
y_strides = data[4] if len(data) > 4 else None
x_strides = data[5] if len(data) > 5 else None
# Create input specifications
inputs = []
# Input tensor x
if x_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(x_shape, x_strides))
else:
inputs.append(TensorSpec.from_tensor(x_shape))
# Weight tensor (1D, always contiguous)
inputs.append(TensorSpec.from_tensor(w_shape))
# Output tensor
if y_strides is not None:
output = TensorSpec.from_strided_tensor(y_shape, y_strides)
else:
output = TensorSpec.from_tensor(y_shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types for individual tensors
_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16]
_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Generate all dtype combinations
_DTYPE_COMBINATIONS = []
for input_dtype in _INPUT_DTYPES:
for weight_dtype in _WEIGHT_DTYPES:
_DTYPE_COMBINATIONS.append(
{
"input_0": input_dtype, # x tensor
"input_1": weight_dtype, # weight tensor
"output": input_dtype, # output tensor (same as input)
}
)
# Base data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16]
# Tolerance
# Per-dtype comparison tolerances; float32 weights get the tightest bounds.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 2e-3, "rtol": 2e-3},
    infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-4},
}

# Data types for individual tensors (input/output vs. weight may differ)
_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16]
_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]

# EPSILON constant for RMSNorm
_EPSILON = 1e-5
def parse_test_cases():
    """
    Build the list of RMSNorm TestCase objects from _TEST_CASES_DATA.

    Each data entry has the format:
        (y_shape, x_shape, w_shape, y_strides, x_strides)

    For every entry and every (input dtype, weight dtype) combination, up to
    three cases are generated: out-of-place, in-place via an explicit output
    tensor, and in-place on the input tensor.

    Returns:
        list[TestCase]: all generated test cases.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        y_shape = data[0]  # Output shape
        x_shape = data[1]  # Input shape
        w_shape = data[2]  # Weight shape (1D)
        y_strides = data[3] if len(data) > 3 else None
        x_strides = data[4] if len(data) > 4 else None

        # Broadcast (zero-stride) tensors cannot be written in place
        x_supports_inplace = not is_broadcast(x_strides)
        y_supports_inplace = not is_broadcast(y_strides)

        for input_dtype in _INPUT_DTYPES:
            for weight_dtype in _WEIGHT_DTYPES:
                # Output dtype matches the input dtype, so its tolerance applies
                tolerance = _TOLERANCE_MAP.get(
                    input_dtype, {"atol": 1e-5, "rtol": 1e-4}
                )

                x_spec = TensorSpec.from_tensor(x_shape, x_strides, input_dtype)
                # Weight is always contiguous
                w_spec = TensorSpec.from_tensor(w_shape, None, weight_dtype)
                y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)

                # Case 1: out-of-place (result returned by the operator)
                test_cases.append(
                    TestCase(
                        inputs=[x_spec, w_spec],
                        kwargs={"epsilon": _EPSILON},
                        output_spec=None,
                        comparison_target=None,
                        tolerance=tolerance,
                        description="RMSNorm - OUT_OF_PLACE",
                    )
                )

                # Case 2: explicit output tensor (rms_norm(x, w, out=y))
                if y_supports_inplace:
                    test_cases.append(
                        TestCase(
                            inputs=[x_spec, w_spec],
                            kwargs={"epsilon": _EPSILON},
                            output_spec=y_spec,
                            comparison_target="out",
                            tolerance=tolerance,
                            description="RMSNorm - INPLACE(out)",
                        )
                    )

                # Case 3: in-place on the input tensor (rms_norm(x, w, out=x))
                if x_supports_inplace:
                    test_cases.append(
                        TestCase(
                            inputs=[x_spec, w_spec],
                            kwargs={"out": 0, "epsilon": _EPSILON},  # out=0 -> first input
                            output_spec=None,
                            comparison_target=0,  # compare against first input
                            tolerance=tolerance,
                            description="RMSNorm - INPLACE(x)",
                        )
                    )
    return test_cases
class OpTest(BaseOperatorTest):
    """RMSNorm operator test comparing InfiniCore against a PyTorch reference."""

    def __init__(self):
        super().__init__("RMSNorm")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, x, weight, epsilon=_EPSILON, out=None, **kwargs):
        """Reference RMSNorm: x / sqrt(mean(x^2) + epsilon) * weight."""
        orig_dtype = x.dtype
        # Do the reduction in float32 for numerical stability.
        x_f32 = x.to(torch.float32)
        w_f32 = weight.to(torch.float32)
        mean_sq = x_f32.pow(2).mean(-1, keepdim=True)
        normalized = x_f32 * torch.rsqrt(mean_sq + epsilon) * w_f32
        # Cast back to the caller's dtype.
        result = normalized.to(orig_dtype)
        if out is None:
            return result
        out.copy_(result)
        return out

    def infinicore_operator(self, x, weight, epsilon=_EPSILON, out=None, **kwargs):
        """InfiniCore RMSNorm implementation under test."""
        return infinicore.rms_norm(x, weight, epsilon, out=out)
def main(): def main():
......
...@@ -7,98 +7,129 @@ import torch ...@@ -7,98 +7,129 @@ import torch
import infinicore import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ============================================================================== # ==============================================================================
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (operation_mode, shape, input_strides, output_strides) # Test cases format: (shape, input_strides, output_strides)
# SiLU is a single-input activation function: output = input * sigmoid(input) # SiLU is a single-input activation function: output = input * sigmoid(input)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
# Basic 2D SiLU # Basic 2D SiLU
(TestCase.BOTH, (2, 4), None, None), ((2, 4), None, None),
(TestCase.BOTH, (128, 64), None, None), ((128, 64), None, None),
# 3D SiLU # 3D SiLU
(TestCase.BOTH, (2, 4, 8), None, None), ((2, 4, 8), None, None),
(TestCase.BOTH, (4, 48, 6), None, None), ((4, 48, 6), None, None),
# Strided tensors # Strided tensors
(TestCase.BOTH, (1, 2048), (4096, 1), (4096, 1)), ((1, 2048), (4096, 1), (4096, 1)),
(TestCase.BOTH, (6, 2560), (2048, 1), (2560, 1)), ((6, 2560), (2048, 1), (2560, 1)),
# Mixed cases # Mixed cases
(TestCase.BOTH, (8, 16, 32), None, None), ((8, 16, 32), None, None),
# Large tensors # Large tensors
(TestCase.BOTH, (16, 5632), None, None), ((16, 5632), None, None),
(TestCase.BOTH, (4, 4, 5632), None, None), ((4, 4, 5632), None, None),
] ]
# Tolerance configuration
def parse_test_cases(data):
"""
Parse silu test case data according to format:
(operation_mode, shape, input_strides, output_strides)
"""
operation_mode = data[0]
shape = data[1]
input_strides = data[2] if len(data) > 2 else None
output_strides = data[3] if len(data) > 3 else None
# Create input specifications
inputs = []
# Tensor input
if input_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, input_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Output tensor
if output_strides is not None:
output = TensorSpec.from_strided_tensor(shape, output_strides)
else:
output = TensorSpec.from_tensor(shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
# Per-dtype comparison tolerances for SiLU.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-3},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-5},
    infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Build the list of SiLU TestCase objects from _TEST_CASES_DATA.

    Each data entry has the format:
        (shape, input_strides, output_strides)

    For every entry and every dtype, up to three cases are generated:
    out-of-place, in-place via an explicit output tensor, and in-place
    on the input tensor.

    Returns:
        list[TestCase]: all generated test cases.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        shape = data[0]
        input_strides = data[1] if len(data) > 1 else None
        output_strides = data[2] if len(data) > 2 else None

        # Broadcast (zero-stride) tensors cannot be written in place
        input_supports_inplace = not is_broadcast(input_strides)
        output_supports_inplace = not is_broadcast(output_strides)

        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})

            input_spec = TensorSpec.from_tensor(shape, input_strides, dtype)
            output_spec = TensorSpec.from_tensor(shape, output_strides, dtype)

            # Case 1: out-of-place (result returned by the operator)
            test_cases.append(
                TestCase(
                    inputs=[input_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="SiLU - OUT_OF_PLACE",
                )
            )

            # Case 2: explicit output tensor (silu(input, out=output))
            if output_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec],
                        kwargs=None,
                        output_spec=output_spec,
                        comparison_target="out",
                        tolerance=tolerance,
                        description="SiLU - INPLACE(out)",
                    )
                )

            # Case 3: in-place on the input tensor (silu(input, out=input))
            if input_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[input_spec],
                        kwargs={"out": 0},  # out=0 -> first input
                        output_spec=None,
                        comparison_target=0,  # compare against first input
                        tolerance=tolerance,
                        description="SiLU - INPLACE(input)",
                    )
                )
    return test_cases
class OpTest(BaseOperatorTest):
    """SiLU operator test comparing InfiniCore against a PyTorch reference."""

    def __init__(self):
        super().__init__("SiLU")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, input, out=None, **kwargs):
        """Reference SiLU: input * sigmoid(input)."""
        result = input * torch.sigmoid(input)
        if out is None:
            return result
        out.copy_(result)
        return out

    def infinicore_operator(self, input, out=None, **kwargs):
        """InfiniCore SiLU implementation under test."""
        return infinicore.silu(input, out=out)
......
...@@ -7,105 +7,145 @@ import torch ...@@ -7,105 +7,145 @@ import torch
import infinicore import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ============================================================================== # ==============================================================================
# Operator-specific configuration # Operator-specific configuration
# ============================================================================== # ==============================================================================
# Test cases format: (operation_mode, shape, a_strides, b_strides, c_strides) # Test cases format: (shape, a_strides, b_strides, c_strides)
# SwiGLU operates element-wise on two tensors of the same shape # SwiGLU operates element-wise on two tensors of the same shape: output = a * b * sigmoid(b)
_TEST_CASES_DATA = [ _TEST_CASES_DATA = [
# Basic 2D SwiGLU # Basic 2D SwiGLU
(TestCase.BOTH, (2, 4), None, None, None), ((2, 4), None, None, None),
(TestCase.BOTH, (128, 64), None, None, None), ((128, 64), None, None, None),
# 3D SwiGLU # 3D SwiGLU
(TestCase.BOTH, (2, 4, 8), None, None, None), ((2, 4, 8), None, None, None),
(TestCase.BOTH, (4, 48, 6), None, None, None), ((4, 48, 6), None, None, None),
# Strided tensors # Strided tensors
(TestCase.BOTH, (1, 2048), (4096, 1), (4096, 1), (4096, 1)), ((1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(TestCase.BOTH, (6, 2560), (2048, 1), (1, 2048), (2560, 1)), ((6, 2560), (2048, 1), (1, 2048), (2560, 1)),
# Mixed cases # Mixed cases
(TestCase.BOTH, (8, 16, 32), None, None, None), ((8, 16, 32), None, None, None),
# Large tensors # Large tensors
(TestCase.BOTH, (16, 5632), None, None, None), ((16, 5632), None, None, None),
(TestCase.BOTH, (4, 4, 5632), None, None, None), ((4, 4, 5632), None, None, None),
] ]
# Tolerance configuration
def parse_test_cases(data):
"""
Parse swiglu test case data according to format:
(operation_mode, shape, a_strides, b_strides, c_strides)
"""
operation_mode = data[0]
shape = data[1]
a_strides = data[2] if len(data) > 2 else None
b_strides = data[3] if len(data) > 3 else None
c_strides = data[4] if len(data) > 4 else None
# Create input specifications
inputs = []
# Tensor a
if a_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, a_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Tensor b
if b_strides is not None:
inputs.append(TensorSpec.from_strided_tensor(shape, b_strides))
else:
inputs.append(TensorSpec.from_tensor(shape))
# Output tensor
if c_strides is not None:
output = TensorSpec.from_strided_tensor(shape, c_strides)
else:
output = TensorSpec.from_tensor(shape)
return TestCase(operation_mode, inputs, output)
# Parse test cases
_TEST_CASES = [parse_test_cases(data) for data in _TEST_CASES_DATA]
# Data types
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
# Tolerance
# Per-dtype comparison tolerances for SwiGLU.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 1e-3, "rtol": 1e-3},
    infinicore.float32: {"atol": 1e-5, "rtol": 1e-5},
    infinicore.bfloat16: {"atol": 5e-3, "rtol": 1e-2},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Build the list of SwiGLU TestCase objects from _TEST_CASES_DATA.

    Each data entry has the format:
        (shape, a_strides, b_strides, c_strides)

    SwiGLU is a two-input operation: output = a * b * sigmoid(b).
    For every entry and every dtype, up to four cases are generated:
    out-of-place, in-place via an explicit output tensor, in-place on
    the first input, and in-place on the second input.

    Returns:
        list[TestCase]: all generated test cases.
    """
    test_cases = []
    for data in _TEST_CASES_DATA:
        shape = data[0]
        a_strides = data[1] if len(data) > 1 else None
        b_strides = data[2] if len(data) > 2 else None
        c_strides = data[3] if len(data) > 3 else None

        # Broadcast (zero-stride) tensors cannot be written in place.
        # NOTE(review): writing the result into a or b additionally requires
        # a_strides == b_strides — presumably because the framework cannot
        # overwrite one operand while reading the other through a different
        # layout; confirm against the framework's in-place semantics.
        a_supports_inplace = not is_broadcast(a_strides) and a_strides == b_strides
        b_supports_inplace = not is_broadcast(b_strides) and a_strides == b_strides
        c_supports_inplace = not is_broadcast(c_strides)

        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 1e-5, "rtol": 1e-4})

            a_spec = TensorSpec.from_tensor(shape, a_strides, dtype)
            b_spec = TensorSpec.from_tensor(shape, b_strides, dtype)
            c_spec = TensorSpec.from_tensor(shape, c_strides, dtype)

            # Case 1: out-of-place (result returned by the operator)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, b_spec],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,
                    tolerance=tolerance,
                    description="SwiGLU - OUT_OF_PLACE",
                )
            )

            # Case 2: explicit output tensor (swiglu(a, b, out=c))
            if c_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs=None,
                        output_spec=c_spec,
                        comparison_target="out",
                        tolerance=tolerance,
                        description="SwiGLU - INPLACE(out)",
                    )
                )

            # Case 3: in-place on the first input (swiglu(a, b, out=a))
            if a_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 0},  # out=0 -> first input
                        output_spec=None,
                        comparison_target=0,  # compare against first input
                        tolerance=tolerance,
                        description="SwiGLU - INPLACE(a)",
                    )
                )

            # Case 4: in-place on the second input (swiglu(a, b, out=b))
            if b_supports_inplace:
                test_cases.append(
                    TestCase(
                        inputs=[a_spec, b_spec],
                        kwargs={"out": 1},  # out=1 -> second input
                        output_spec=None,
                        comparison_target=1,  # compare against second input
                        tolerance=tolerance,
                        description="SwiGLU - INPLACE(b)",
                    )
                )
    return test_cases
class OpTest(BaseOperatorTest):
    """SwiGLU operator test comparing InfiniCore against a PyTorch reference."""

    def __init__(self):
        super().__init__("SwiGLU")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, a, b, out=None, **kwargs):
        """Reference SwiGLU: a * b * sigmoid(b)."""
        result = a * b * torch.sigmoid(b)
        if out is None:
            return result
        out.copy_(result)
        return out

    def infinicore_operator(self, a, b, out=None, **kwargs):
        """InfiniCore SwiGLU implementation under test."""
        return infinicore.swiglu(a, b, out=out)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment