Unverified Commit 6680a8c8 authored by PanZezhong1725, committed by GitHub

issue/436: fix issues encountered in end-to-end inference on Kunlun (昆仑芯) (#437)



* issue/436: support kunlun rope U32

* issue/436: support 9g7b 4b models

---------
Co-authored-by: zhangyue <zhangyue@qiyuanlab.com>
parents badccb86 3bdd832e
@@ -12,7 +12,7 @@
 namespace device::kunlun::kernel {
-#define SM_SIZE 10240
+#define SM_SIZE 40960
 /**
  * @brief Define ptrdiff_t and size_t for kunlun xpu
@@ -102,6 +102,8 @@ infiniStatus_t Descriptor::calculate(
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
         return INFINI_STATUS_SUCCESS;
     }));
+    xpu_wait(stream);
+    return INFINI_STATUS_SUCCESS;
 }
@@ -120,13 +120,13 @@ Descriptor::calculate(
     switch (_info.dt_p) {
     case INFINI_DTYPE_F16:
         LAUNCH_KERNEL(half, int32_t);
-        return INFINI_STATUS_SUCCESS;
+        break;
     case INFINI_DTYPE_BF16:
         LAUNCH_KERNEL(bfloat16_t, int32_t);
-        return INFINI_STATUS_SUCCESS;
+        break;
     case INFINI_DTYPE_F32:
         LAUNCH_KERNEL(float, int32_t);
-        return INFINI_STATUS_SUCCESS;
+        break;
     default:
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
@@ -135,13 +135,13 @@ Descriptor::calculate(
     switch (_info.dt_p) {
     case INFINI_DTYPE_F16:
         LAUNCH_KERNEL(half, int64_t);
-        return INFINI_STATUS_SUCCESS;
+        break;
     case INFINI_DTYPE_BF16:
         LAUNCH_KERNEL(bfloat16_t, int64_t);
-        return INFINI_STATUS_SUCCESS;
+        break;
     case INFINI_DTYPE_F32:
         LAUNCH_KERNEL(float, int64_t);
-        return INFINI_STATUS_SUCCESS;
+        break;
     default:
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
@@ -159,19 +159,34 @@ infiniStatus_t Descriptor::calculate(
         switch (_info.data_type) {
         case INFINI_DTYPE_F32:
             LAUNCH_KERNEL(float, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_F16:
             LAUNCH_KERNEL(half, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         case INFINI_DTYPE_BF16:
             LAUNCH_KERNEL(bfloat16_t, int32_t);
-            return INFINI_STATUS_SUCCESS;
+            break;
         default:
             return INFINI_STATUS_BAD_TENSOR_DTYPE;
         }
+    } else if (_info.pos_type == INFINI_DTYPE_U32) {
+        switch (_info.data_type) {
+        case INFINI_DTYPE_F32:
+            LAUNCH_KERNEL(float, uint32_t);
+            break;
+        case INFINI_DTYPE_F16:
+            LAUNCH_KERNEL(half, uint32_t);
+            break;
+        case INFINI_DTYPE_BF16:
+            LAUNCH_KERNEL(bfloat16_t, uint32_t);
+            break;
+        default:
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;
+        }
     } else {
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
+    return INFINI_STATUS_SUCCESS;
 }
 } // namespace op::rope::kunlun
@@ -2,6 +2,7 @@ from ctypes import c_uint64
 import ctypes
 import sys
 import os
+import torch
 sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
 from libinfiniop import (
@@ -21,7 +22,6 @@ from libinfiniop import (
     infiniopOperatorDescriptor_t,
 )
-import torch
 def causal_softmax(x):
 from typing import Sequence
 import torch
 import ctypes
+import numpy as np
 from .datatypes import *
 from .devices import *
 from .liboperators import infiniopTensorDescriptor_t, LIBINFINIOP, infiniopHandle_t
@@ -87,6 +88,12 @@ class TestTensor(CTensor):
             self._torch_tensor = set_tensor.to(to_torch_dtype(dt)).to(
                 torch_device_map[device]
             )
+        elif mode == "binary":
+            assert set_tensor is not None
+            assert torch_shape == list(set_tensor.shape)
+            self._torch_tensor = set_tensor.to(to_torch_dtype(dt)).to(
+                torch_device_map[device]
+            )
         else:
             raise ValueError("Unsupported mode")
@@ -95,7 +102,7 @@
         if bias is not None:
             self._torch_tensor += bias
-        if strides is not None:
+        if strides is not None and mode != "binary":
             self._data_tensor = rearrange_tensor(self._torch_tensor, torch_strides)
         else:
             self._data_tensor = self._torch_tensor.clone()
@@ -113,6 +120,14 @@
     def is_broadcast(self):
         return self.strides is not None and 0 in self.strides
+    @staticmethod
+    def from_binary(binary_file, shape, strides, dt: InfiniDtype, device: InfiniDeviceEnum):
+        data = np.fromfile(binary_file, dtype=to_numpy_dtype(dt))
+        base = torch.from_numpy(data)
+        torch_tensor = torch.as_strided(base, size=shape, stride=strides).to(torch_device_map[device])
+        return TestTensor(
+            shape, strides, dt, device, mode="binary", set_tensor=torch_tensor)
+
     @staticmethod
     def from_torch(torch_tensor, dt: InfiniDtype, device: InfiniDeviceEnum):
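The new `from_binary` helper loads a flat dump with `np.fromfile` and re-strides it with `torch.as_strided`. A minimal sketch of the same pattern, using a hypothetical file name and small shape rather than anything from this repo:

```python
import numpy as np
import torch

# Hypothetical dump: six float32 values written back-to-back.
np.arange(6, dtype=np.float32).tofile("weights.bin")

# Read the flat buffer, then view it with an explicit shape and stride,
# mirroring what TestTensor.from_binary does before wrapping the result.
flat = torch.from_numpy(np.fromfile("weights.bin", dtype=np.float32))
tensor = torch.as_strided(flat, size=(2, 3), stride=(3, 1))
print(tensor)  # tensor([[0., 1., 2.], [3., 4., 5.]])
```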
@@ -154,6 +169,38 @@ def to_torch_dtype(dt: InfiniDtype, compatability_mode=False):
         raise ValueError("Unsupported data type")
+
+def to_numpy_dtype(dt: InfiniDtype, compatability_mode=False):
+    if dt == InfiniDtype.I8:
+        return np.int8
+    elif dt == InfiniDtype.I16:
+        return np.int16
+    elif dt == InfiniDtype.I32:
+        return np.int32
+    elif dt == InfiniDtype.I64:
+        return np.int64
+    elif dt == InfiniDtype.U8:
+        return np.uint8
+    elif dt == InfiniDtype.U16:
+        return np.uint16 if not compatability_mode else np.int16
+    elif dt == InfiniDtype.U32:
+        return np.uint32 if not compatability_mode else np.int32
+    elif dt == InfiniDtype.U64:
+        return np.uint64 if not compatability_mode else np.int64
+    elif dt == InfiniDtype.F16:
+        return np.float16
+    elif dt == InfiniDtype.BF16:
+        # numpy 1.20+ can expose an emulated bf16 as np.dtype("bfloat16"),
+        # but many environments lack direct support and usually fall back to float32
+        return np.dtype("bfloat16") if not compatability_mode else np.float32
+    elif dt == InfiniDtype.F32:
+        return np.float32
+    elif dt == InfiniDtype.F64:
+        return np.float64
+    else:
+        raise ValueError("Unsupported data type")
+
 class TestWorkspace:
     def __init__(self, size, device):
         if size != 0:
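As the comment in `to_numpy_dtype` notes, stock numpy has no native bfloat16; `np.dtype("bfloat16")` only resolves when a package such as ml_dtypes has registered it (ml_dtypes is an assumption here, not a dependency of this diff). A hedged sketch of a runtime probe:

```python
import numpy as np

# Probe for a registered bfloat16; otherwise fall back to float32,
# which is what compatability_mode=True does unconditionally.
try:
    import ml_dtypes  # noqa: F401 -- optional package that registers "bfloat16"
    bf16 = np.dtype("bfloat16")
except (ImportError, TypeError):
    bf16 = np.dtype(np.float32)

print(bf16)
```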
@@ -422,6 +469,9 @@ def print_discrepancy(
     is_terminal = sys.stdout.isatty()
     actual = actual.to("cpu")
     expected = expected.to("cpu")
+    actual_isnan = torch.isnan(actual)
+    expected_isnan = torch.isnan(expected)
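The new `isnan` masks in `print_discrepancy` exist because `NaN != NaN` under IEEE 754, so a plain elementwise comparison would report positions where both tensors are NaN as discrepancies. A small illustration of the idea (variable names here are mine, not the harness's):

```python
import torch

actual = torch.tensor([1.0, float("nan"), 3.0])
expected = torch.tensor([1.0, float("nan"), 2.0])

# Treat "both NaN" as agreement; only index 2 is a real mismatch.
both_nan = torch.isnan(actual) & torch.isnan(expected)
mismatch = (actual != expected) & ~both_nan
print(mismatch)  # tensor([False, False,  True])
```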
@@ -75,6 +75,7 @@ _TEST_CASES = [
         row_major_strides((3, 4, 50, 50, 5, 7)),  # x_stride
         column_major_strides((3, 4, 50, 50, 5, 7)),  # y_stride
     ),
+    ((15, 10752), (0, 1), (10752, 1)),
 ]
 # Data types used for testing
@@ -94,7 +95,7 @@ NUM_ITERATIONS = 1000
 def rearrange_torch(y, x, x_shape, y_stride):
     y.set_(y.untyped_storage(), 0, x_shape, y_stride)
-    y[:] = x.view_as(y)
+    y.copy_(x.expand_as(y))
 def test(
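The `rearrange_torch` fix pairs with the new `((15, 10752), (0, 1), (10752, 1))` case above: an x_stride of `(0, 1)` describes a broadcast view, and `copy_` plus `expand_as` copies from such sources by broadcasting, instead of requiring a view-compatible memory layout the way `view_as` does. A minimal reproduction of the pattern at a small size (shapes are mine):

```python
import torch

base = torch.arange(4.0)
# Stride-0 first dimension: all three "rows" alias the same four values,
# analogous to the (0, 1) x_stride in the new test case.
x = torch.as_strided(base, size=(3, 4), stride=(0, 1))

y = torch.empty_strided((3, 4), (4, 1))
y.copy_(x.expand_as(y))  # broadcast-aware copy into the strided output
assert torch.equal(y, base.repeat(3, 1))
```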
@@ -30,6 +30,7 @@ _TEST_CASES_ = [
     ((2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1)),
     ((16, 2048), (16, 2048), (2048,), None, None),
     ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1)),
+    ((15, 3584), (15, 3584), (3584,), None, None),
     ((4, 4, 2048), (4, 4, 2048), (2048,), None, None),
     ((4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1)),
     ((4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1)),