Unverified commit 784139b9, authored by thatPepe and committed by GitHub
Browse files

Merge pull request #990 from InfiniTensor/demo131

Demo-131 Cuda graph with optimized paged attention
parents 3c8fb3c0 1d6527cb
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (anchor_shape, positive_shape, negative_shape, strides_or_None, margin_or_None, p_or_None, eps_or_None, swap_or_None)
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (anchor_shape, positive_shape, negative_shape, strides_or_None, margin_or_None, swap_or_None)
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (in_shape, in_strides_or_None, kernel_size, dilation, padding, stride)
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (input_shape, input_strides_or_None, N)
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import (
BaseOperatorTest,
TensorSpec,
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (vec1_shape, vec2_shape, vec1_strides_or_None, vec2_strides_or_None)
......
......@@ -3,8 +3,8 @@ import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import torch
from framework import BaseOperatorTest, TensorSpec, TestCase, GenericTestRunner
# Test cases format: (condition_shape, cond_strides_or_None, x_shape_or_None, y_shape_or_None)
......
......@@ -183,6 +183,7 @@ def func6_initialize_device_relationship():
_infinicore.Device.Type.QY, # 9 "cuda"
_infinicore.Device.Type.KUNLUN, # 7 "cuda"
_infinicore.Device.Type.HYGON, # 8 "cuda"
_infinicore.Device.Type.ALI, # 10 "cuda"
]
if True:
print("\n ---------- 测试 CPU")
......
......@@ -32,8 +32,24 @@ _TEST_CASES_ = [
((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)),
((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(2048, 8192, 1),
(2048, 8192, 1),
(2048, 8192, 1),
),
(
(4, 4, 2048),
(4, 4, 2048),
(4, 4, 2048),
(2048,),
(16384, 4096, 1),
(16384, 4096, 1),
(16384, 4096, 1),
),
((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None),
((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None),
]
......@@ -97,7 +113,9 @@ def test(
w = TestTensor(w_shape, None, w_dtype, device)
eps = 1e-6
add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps)
add_rms_norm(
y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps
)
if sync is not None:
sync()
......@@ -109,11 +127,11 @@ def test(
handle,
ctypes.byref(descriptor),
y.descriptor,
residual_out.descriptor,
a.descriptor,
b.descriptor,
w.descriptor,
eps,
residual_out.descriptor,
)
)
......@@ -136,10 +154,10 @@ def test(
workspace.data(),
workspace_size.value,
y.data(),
residual_out.data(),
a.data(),
b.data(),
w.data(),
residual_out.data(),
None,
)
)
......@@ -147,18 +165,22 @@ def test(
lib_add_rms_norm()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
# Verify normalized result (y)
if DEBUG:
debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
# Verify add result (residual_out) - should be a + b
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32)
expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(
torch.float32
)
expected_residual = expected_residual.to(a.torch_tensor().dtype)
if DEBUG:
debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
assert torch.allclose(
residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol
)
# Profiling workflow
if PROFILE:
......
......@@ -15,6 +15,7 @@ from libinfiniop import (
InfiniDtype,
InfiniDtypeNames,
InfiniDeviceNames,
InfiniDeviceEnum,
infiniopOperatorDescriptor_t,
)
from enum import Enum, auto
......@@ -83,6 +84,12 @@ def test(
dtype=torch.float16,
sync=None,
):
# Skip strided cases on Iluvatar: GELU with non-contiguous tensors can hang the GPU (requires ixsmi -r to recover)
if device == InfiniDeviceEnum.ILUVATAR and (
input_stride is not None or output_stride is not None
):
return
input = TestTensor(shape, input_stride, dtype, device)
if inplace == Inplace.INPLACE:
if input_stride != output_stride:
......@@ -141,6 +148,9 @@ def test(
lib_gelu()
if sync is not None:
sync()
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(output.actual_tensor(), output.torch_tensor(), atol=atol, rtol=rtol)
......
......@@ -9,6 +9,7 @@ class InfiniDeviceEnum:
KUNLUN = 7
HYGON = 8
QY = 9
ALI = 10
InfiniDeviceNames = {
......@@ -22,6 +23,7 @@ InfiniDeviceNames = {
InfiniDeviceEnum.KUNLUN: "Kunlun",
InfiniDeviceEnum.HYGON: "Hygon",
InfiniDeviceEnum.QY: "QY",
InfiniDeviceEnum.ALI: "Ali",
}
# Mapping that maps InfiniDeviceEnum to torch device string
......@@ -36,4 +38,5 @@ torch_device_map = {
InfiniDeviceEnum.KUNLUN: "cuda",
InfiniDeviceEnum.HYGON: "cuda",
InfiniDeviceEnum.QY: "cuda",
InfiniDeviceEnum.ALI: "cuda",
}
......@@ -393,6 +393,7 @@ def add_rms_norm_(lib):
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t,
c_float,
]
......@@ -412,6 +413,7 @@ def add_rms_norm_(lib):
c_void_p,
c_void_p,
c_void_p,
c_void_p,
]
lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
......@@ -723,6 +725,41 @@ def dequantize_(lib):
]
@OpRegister.operator
def per_channel_quant_int8_(lib):
    """Declare ctypes signatures for the PerChannelQuantI8 operator on ``lib``.

    Sets ``restype`` and ``argtypes`` on the four entry points looked up on
    ``lib``: descriptor creation, workspace-size query, execution, and
    descriptor destruction. Every entry point is declared to return
    ``c_int32``.

    Args:
        lib: Object exposing the ``infiniop*PerChannelQuantI8*`` functions as
            attributes (presumably the loaded ctypes library -- confirm
            against the caller).
    """
    # Descriptor creation: handle, output descriptor pointer, then four
    # tensor descriptors.
    create_fn = lib.infiniopCreatePerChannelQuantI8Descriptor
    create_fn.restype = c_int32
    create_fn.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
    ] + [infiniopTensorDescriptor_t] * 4

    # Workspace-size query: descriptor plus a size_t out-pointer.
    workspace_fn = lib.infiniopGetPerChannelQuantI8WorkspaceSize
    workspace_fn.restype = c_int32
    workspace_fn.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    # Execution: descriptor, one pointer, one size, then five more pointers.
    # NOTE(review): presumably workspace pointer/size followed by tensor data
    # pointers and a stream -- confirm against the C header.
    run_fn = lib.infiniopPerChannelQuantI8
    run_fn.restype = c_int32
    run_fn.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
    ] + [c_void_p] * 5

    # Descriptor destruction takes only the descriptor.
    destroy_fn = lib.infiniopDestroyPerChannelQuantI8Descriptor
    destroy_fn.restype = c_int32
    destroy_fn.argtypes = [infiniopOperatorDescriptor_t]
@OpRegister.operator
def softplus_(lib):
lib.infiniopCreateSoftplusDescriptor.restype = c_int32
......@@ -1144,3 +1181,35 @@ def paged_attention_prefill_(lib):
lib.infiniopDestroyPagedAttentionPrefillDescriptor.argtypes = [
infiniopOperatorDescriptor_t,
]
@OpRegister.operator
def silu_and_mul(lib):
    """Declare ctypes signatures for the SiluAndMul operator on ``lib``.

    Sets ``restype`` and ``argtypes`` on the four entry points looked up on
    ``lib``: descriptor creation, workspace-size query, execution, and
    descriptor destruction. Every entry point is declared to return
    ``c_int32``.

    Args:
        lib: Object exposing the ``infiniop*SiluAndMul*`` functions as
            attributes (presumably the loaded ctypes library -- confirm
            against the caller).
    """
    # Descriptor creation: handle, output descriptor pointer, then two
    # tensor descriptors.
    create_fn = lib.infiniopCreateSiluAndMulDescriptor
    create_fn.restype = c_int32
    create_fn.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
    ] + [infiniopTensorDescriptor_t] * 2

    # Workspace-size query: descriptor plus a size_t out-pointer.
    workspace_fn = lib.infiniopGetSiluAndMulWorkspaceSize
    workspace_fn.restype = c_int32
    workspace_fn.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    # Execution: descriptor, one pointer, one size, then three more pointers.
    # NOTE(review): presumably workspace pointer/size followed by tensor data
    # pointers and a stream -- confirm against the C header.
    run_fn = lib.infiniopSiluAndMul
    run_fn.restype = c_int32
    run_fn.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
    ] + [c_void_p] * 3

    # Descriptor destruction takes only the descriptor.
    destroy_fn = lib.infiniopDestroySiluAndMulDescriptor
    destroy_fn.restype = c_int32
    destroy_fn.argtypes = [infiniopOperatorDescriptor_t]
......@@ -433,6 +433,11 @@ def get_args():
action="store_true",
help="Run HYGON DCU test",
)
parser.add_argument(
"--ali",
action="store_true",
help="Run ALI PPU test",
)
return parser.parse_args()
......@@ -487,6 +492,7 @@ def filter_tensor_dtypes_by_device(device, tensor_dtypes):
InfiniDeviceEnum.ASCEND,
InfiniDeviceEnum.ILUVATAR,
InfiniDeviceEnum.CAMBRICON,
InfiniDeviceEnum.ALI,
):
return tensor_dtypes
else:
......@@ -757,6 +763,10 @@ def get_test_devices(args):
import torch
devices_to_test.append(InfiniDeviceEnum.HYGON)
if args.ali:
import torch
devices_to_test.append(InfiniDeviceEnum.ALI)
if not devices_to_test:
devices_to_test = [InfiniDeviceEnum.CPU]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment