Commit 0450fb1e authored by PanZezhong's avatar PanZezhong
Browse files

issue/161 rope和causal softmax支持非原地

parent 61734617
......@@ -10,14 +10,13 @@
#include "infiniop/ops/expand.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/global_avg_pool.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/max_pool.h"
#include "infiniop/ops/mlp.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
#include "infiniop/ops/rms_norm.h"
#include "infiniop/ops/rotary_embedding.h"
#include "infiniop/ops/rope.h"
#include "infiniop/ops/swiglu.h"
#include "infiniop/tensor_descriptor.h"
......
......@@ -5,17 +5,21 @@
typedef struct InfiniopDescriptor *infiniopCausalSoftmaxDescriptor_t;
__C __export infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(infiniopHandle_t handle,
infiniopCausalSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc);
__C __export infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
infiniopHandle_t handle,
infiniopCausalSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc);
__C __export infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDescriptor_t desc, size_t *size);
__C __export infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *data,
void *stream);
__C __export infiniStatus_t infiniopCausalSoftmax(
infiniopCausalSoftmaxDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void *stream);
__C __export infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxDescriptor_t desc);
......
#ifndef __INFINIOP_ROTARY_EMBEDDING_API_H__
#define __INFINIOP_ROTARY_EMBEDDING_API_H__
#ifndef __INFINIOP_ROPE_API_H__
#define __INFINIOP_ROPE_API_H__
#include "../operator_descriptor.h"
......@@ -8,7 +8,8 @@ typedef struct InfiniopDescriptor *infiniopRoPEDescriptor_t;
__C __export infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle,
infiniopRoPEDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t t,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table);
......@@ -19,7 +20,8 @@ __C __export infiniStatus_t infiniopRoPE(
infiniopRoPEDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *t,
void *y,
const void *x,
void const *pos_ids,
void const *sin_table,
void const *cos_table,
......
......@@ -32,11 +32,13 @@
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc); \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *data, \
void *y, \
const void *x, \
void *stream) const; \
}; \
}
......
......@@ -9,44 +9,46 @@ Descriptor::~Descriptor() {}
// Create a causal-softmax descriptor for output y and input x.
// Validation (matching dtype/shape, stride extraction) is delegated to
// CausalSoftmaxInfo::create; on success the info is moved into a new
// Descriptor. (This span previously contained both the pre- and
// post-change signatures from a mangled diff; this is the post-change
// version with the single (y_desc, x_desc) parameter list.)
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto result = CausalSoftmaxInfo::create(y_desc, x_desc);
    CHECK_RESULT(result);
    // CPU backend: no opaque backend state (nullptr) and no workspace (0).
    *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
// Row-wise causal softmax over the last dimension: for row i, positions
// j >= total_seq_len - seq_len + i + 1 are masked (written as 0) and the
// softmax is computed over the unmasked prefix, using the numerically
// stable max-subtraction form. Supports out-of-place (y != x) as well as
// in-place (y == x): each row of x is fully consumed by the exp pass
// before the normalization pass, which only touches y.
template <typename T>
infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *y, const T *x) {
#pragma omp parallel for
    for (ptrdiff_t index = 0; index < ptrdiff_t(info->batch_size * info->seq_len); index++) {
        size_t batch = index / info->seq_len;
        size_t i = (index % info->seq_len);
        ptrdiff_t y_offset = batch * info->y_stride_b + i * info->y_stride_i;
        ptrdiff_t x_offset = batch * info->x_stride_b + i * info->x_stride_i;
        T *y_ = y + y_offset;
        const T *x_ = x + x_offset;
        // Number of unmasked (attended) positions in this row.
        size_t valid = info->total_seq_len - info->seq_len + i + 1;
        // Zero out the masked (future) positions of the output row.
        for (size_t j = valid; j < info->total_seq_len; j++) {
            if constexpr (std::is_same<T, fp16_t>::value) {
                y_[j * info->y_stride_j] = utils::cast<fp16_t>(0.0f);
            } else {
                y_[j * info->y_stride_j] = 0.0f;
            }
        }
        // Stable softmax: subtract the row maximum before exponentiation.
        float val = op::common_cpu::reduce_op::max(x_, valid, info->x_stride_j);
        for (size_t j = 0; j < valid; j++) {
            if constexpr (std::is_same<T, fp16_t>::value) {
                y_[j * info->y_stride_j] = utils::cast<fp16_t>(std::exp(utils::cast<float>(x_[j * info->x_stride_j]) - val));
            } else {
                y_[j * info->y_stride_j] = std::exp(x_[j * info->x_stride_j] - val);
            }
        }
        float sum = op::common_cpu::reduce_op::sum(y_, valid, info->y_stride_j);
        for (size_t j = 0; j < valid; j++) {
            if constexpr (std::is_same<T, fp16_t>::value) {
                y_[j * info->y_stride_j] = utils::cast<fp16_t>(utils::cast<float>(y_[j * info->y_stride_j]) / sum);
            } else {
                // BUG FIX: the float path read y_[y_offset + j * y_stride_j],
                // adding y_offset a second time (y_ already includes it) and
                // indexing outside the row. Index relative to y_ only.
                y_[j * info->y_stride_j] = y_[j * info->y_stride_j] / sum;
            }
        }
    }
......@@ -56,13 +58,14 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *data) {
// Dispatch the causal-softmax kernel by the dtype recorded at descriptor
// creation. workspace/workspace_size/stream are unused by the CPU backend
// but kept for API uniformity across backends. (This span previously
// contained both the pre-change `void *data` and post-change y/x parameter
// lists from a mangled diff; this is the post-change version.)
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y,
    const void *x,
    void *stream) const {
    if (_info.dtype == INFINI_DTYPE_F16) {
        CHECK_STATUS(causal_softmax<fp16_t>(&_info, (fp16_t *)y, (const fp16_t *)x));
    } else if (_info.dtype == INFINI_DTYPE_F32) {
        CHECK_STATUS(causal_softmax<float>(&_info, (float *)y, (const float *)x));
    } else {
        // Info creation only admits F16/F32; guard anyway.
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
......
......@@ -13,45 +13,69 @@ class CausalSoftmaxInfo {
public:
infiniDtype_t dtype;
size_t batch_size;
ptrdiff_t stride_b;
size_t seq_len;
ptrdiff_t stride_i;
size_t total_seq_len;
ptrdiff_t stride_j;
/// Validate the (y, x) descriptor pair and extract the metadata the
/// kernels need. Requirements: identical dtype (F16 or F32), identical
/// shape, ndim of 2 or 3, and total_seq_len >= seq_len (the causal mask
/// assumes the last dim is at least as long as the row dim).
/// (This span previously interleaved the old single-descriptor version
/// with the new one from a mangled diff; this is the post-change version.)
static utils::Result<CausalSoftmaxInfo> create(infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) {
    auto dtype = y_desc->dtype();
    if (dtype != x_desc->dtype()) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
    auto ndim = y_desc->ndim();
    if (ndim != x_desc->ndim()) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    if (ndim != 2 && ndim != 3) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    auto shape = y_desc->shape();
    if (!SAME_VEC(y_desc->shape(), x_desc->shape())) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    // Last dim (total_seq_len) must cover the row dim (seq_len).
    if (shape[ndim - 1] < shape[ndim - 2]) {
        return INFINI_STATUS_BAD_TENSOR_SHAPE;
    }
    size_t batch_size = 1;
    size_t seq_len = shape[ndim - 2];
    size_t total_seq_len = shape[ndim - 1];
    // Batch strides stay 0 for 2-D tensors (implicit single batch).
    ptrdiff_t y_stride_b = 0,
              x_stride_b = 0;
    ptrdiff_t y_stride_i = y_desc->stride(ndim - 2),
              y_stride_j = y_desc->stride(ndim - 1);
    ptrdiff_t x_stride_i = x_desc->stride(ndim - 2),
              x_stride_j = x_desc->stride(ndim - 1);
    if (ndim == 3) {
        y_stride_b = y_desc->stride(0);
        x_stride_b = x_desc->stride(0);
        batch_size = shape[0];
    }
    return utils::Result<CausalSoftmaxInfo>(CausalSoftmaxInfo{
        dtype,
        batch_size,
        seq_len,
        total_seq_len,
        y_stride_b,
        y_stride_i,
        y_stride_j,
        x_stride_b,
        x_stride_i,
        x_stride_j});
}
};
......
......@@ -9,14 +9,16 @@
__C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor(
infiniopHandle_t handle,
infiniopCausalSoftmaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y_desc) {
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
#define CREATE(CASE, NAMESPACE) \
case CASE: \
return op::causal_softmax::NAMESPACE::Descriptor::create( \
handle, \
reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor **>(desc_ptr), \
y_desc);
y_desc, \
x_desc);
switch (handle->device) {
#ifdef ENABLE_CPU_API
......@@ -96,12 +98,17 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *data, void *stream) {
__C infiniStatus_t infiniopCausalSoftmax(
infiniopCausalSoftmaxDescriptor_t desc,
void *workspace, size_t workspace_size,
void *y,
const void *x,
void *stream) {
#define CALCULATE(CASE, NAMESPACE) \
case CASE: \
return reinterpret_cast<op::causal_softmax::NAMESPACE::Descriptor *>(desc)->calculate( \
workspace, workspace_size, data, stream);
workspace, workspace_size, y, x, stream);
switch (desc->device_type) {
#ifdef ENABLE_CPU_API
......
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/rotary_embedding.h"
#include "infiniop/ops/rope.h"
__C infiniStatus_t infiniopCreateRoPEDescriptor(
infiniopHandle_t handle, infiniopRoPEDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t t, infiniopTensorDescriptor_t pos_ids,
infiniopHandle_t handle,
infiniopRoPEDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t y,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table) {
switch (handle->device) {
......@@ -91,11 +94,16 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
__C infiniStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
void *workspace, size_t workspace_size,
void *t, const void *pos_ids,
const void *sin_table, const void *cos_table,
void *stream) {
__C infiniStatus_t infiniopRoPE(
infiniopRoPEDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *y,
const void *x,
void const *pos_ids,
void const *sin_table,
void const *cos_table,
void *stream) {
switch (desc->device_type) {
#ifdef ENABLE_CPU
case DevCpu:
......
......@@ -16,18 +16,20 @@ from libinfiniop import (
get_tolerance,
profile_operation,
)
from enum import Enum, auto
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# x_shape, x_stride
((32, 512), None),
((32, 512), (1024, 1)),
((32, 5, 5), None),
((32, 20, 512), None),
((32, 20, 512), (20480, 512, 1)), # Ascend 暂不支持非连续
_TEST_CASES_ = [
# shape, x_stride, y_stride
((3, 3), None, None),
((32, 512), None, None),
((32, 512), (1024, 1), (1024, 1)),
((32, 5, 5), None, None),
((32, 20, 512), None, None),
((32, 20, 512), (20480, 512, 1), None), # Ascend 暂不支持非连续
]
# Data types used for testing
......@@ -38,6 +40,23 @@ _TOLERANCE_MAP = {
torch.float16: {"atol": 0, "rtol": 1e-2},
}
# Aliasing mode for a test case: out-of-place writes to a separate y
# tensor, INPLACE_X reuses the input tensor as the output.
class Inplace(Enum):
    # y is a distinct tensor from x.
    OUT_OF_PLACE = auto()
    # y aliases x (in-place update).
    INPLACE_X = auto()
# Every base case is exercised under both aliasing modes.
_INPLACE = [
    Inplace.OUT_OF_PLACE,
    Inplace.INPLACE_X,
]

# Cross each base case with each aliasing mode, appending the mode as
# the final element of the case tuple.
_TEST_CASES = [
    (*base_case, aliasing)
    for base_case in _TEST_CASES_
    for aliasing in _INPLACE
]
# When True, print per-element diffs via debug() before asserting.
DEBUG = False
# When True, run the profiling workflow comparing torch vs. lib timings.
PROFILE = False
# Warm-up iteration count — presumably consumed by profile_operation; confirm there.
NUM_PRERUN = 10
......@@ -59,12 +78,21 @@ def causal_softmax(x):
return torch.nn.functional.softmax(masked, dim=-1).to(type)
def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
def test(
lib,
handle,
torch_device,
shape,
x_stride=None,
y_stride=None,
inplace=Inplace.OUT_OF_PLACE,
dtype=torch.float16,
):
print(
f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
f"Testing CausalSoftmax on {torch_device} with shape:{shape} x_stride:{x_stride} y_stride:{y_stride} dtype:{dtype} inplace:{inplace}"
)
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
x = torch.rand(shape, dtype=dtype).to(torch_device)
ans = causal_softmax(x)
......@@ -72,10 +100,18 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16)
x_tensor = to_tensor(x, lib)
if inplace == Inplace.INPLACE_X:
y = x
y_tensor = x_tensor
else:
y = torch.zeros(shape, dtype=dtype).to(torch_device)
y = rearrange_if_needed(y, y_stride)
y_tensor = to_tensor(y, lib)
descriptor = infiniopCausalSoftmaxDescriptor_t()
check_error(
lib.infiniopCreateCausalSoftmaxDescriptor(
handle, ctypes.byref(descriptor), x_tensor.descriptor
handle, ctypes.byref(descriptor), y_tensor.descriptor, x_tensor.descriptor
)
)
......@@ -96,6 +132,7 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16)
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
y_tensor.data,
x_tensor.data,
None,
)
......@@ -105,8 +142,8 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16)
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(x, ans, atol=atol, rtol=rtol)
assert torch.allclose(x, ans, atol=atol, rtol=rtol)
debug(y, ans, atol=atol, rtol=rtol)
assert torch.allclose(y, ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
......
......@@ -232,5 +232,5 @@ if __name__ == "__main__":
# Execute tests
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment