Unverified Commit 784139b9 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #990 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (in_shape, in_strides_or_None)
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
CaseResult,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......@@ -30,8 +30,24 @@ _TEST_CASES_DATA = [
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(2048, 8192, 1),
(2048, 8192, 1),
(2048, 8192, 1),
),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(16384, 4096, 1),
(16384, 4096, 1),
(16384, 4096, 1),
),
]
# Tolerance configuration
......@@ -87,12 +103,14 @@ def parse_test_cases():
y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype)
# Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result)
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
residual_out_spec = TensorSpec.from_tensor(
a_shape, a_strides, input_dtype
)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON},
output_specs=[y_spec, residual_out_spec], # Two outputs
output_specs=None, # Two outputs
comparison_target=None,
tolerance=tolerance,
output_count=2, # Two outputs: normalized_result and add_result
......@@ -101,19 +119,25 @@ def parse_test_cases():
)
# Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w))
if y_supports_inplace:
residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype)
test_cases.append(
TestCase(
inputs=[a_spec, b_spec, w_spec],
kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)},
output_specs=[y_spec, residual_out_spec], # Two outputs
comparison_target="out",
tolerance=tolerance,
output_count=2,
description=f"AddRMSNorm - INPLACE(out)",
)
)
# if y_supports_inplace:
# residual_out_spec = TensorSpec.from_tensor(
# a_shape, a_strides, input_dtype
# )
# test_cases.append(
# TestCase(
# inputs=[a_spec, b_spec, w_spec],
# kwargs={
# "epsilon": _EPSILON,
# "out": y_spec,
# "residual": residual_out_spec,
# },
# output_specs=[y_spec, residual_out_spec], # Two outputs
# comparison_target="out",
# tolerance=tolerance,
# output_count=2,
# description=f"AddRMSNorm - INPLACE(out)",
# )
# )
return test_cases
......@@ -127,7 +151,9 @@ class OpTest(BaseOperatorTest):
def get_test_cases(self):
return parse_test_cases()
def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
def torch_operator(
self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs
):
"""PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)"""
input_dtype = a.dtype
......@@ -144,21 +170,19 @@ class OpTest(BaseOperatorTest):
add_result = sum_tensor.to(input_dtype)
if out is not None:
# For in-place operations, we need to handle the output tuple
if isinstance(out, (tuple, list)) and len(out) == 2:
out[0].copy_(normalized_result)
out[1].copy_(add_result)
return tuple(out)
else:
# Single output - just return normalized result for backward compatibility
out.copy_(normalized_result)
return out
out.copy_(normalized_result)
if residual is not None:
residual.copy_(add_result)
return (normalized_result, add_result)
def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs):
def infinicore_operator(
self, a, b, weight, epsilon=_EPSILON, out=None, residual=None, **kwargs
):
"""InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)"""
return infinicore.add_rms_norm(a, b, weight, epsilon, out=out)
return infinicore.add_rms_norm(
a, b, weight, epsilon, out=out, residual=residual
)
def main():
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment