Commit 8d1207dd authored by PanZezhong

issue/219 Add attention operator

parent e580d751
......@@ -18,6 +18,7 @@ def run_tests(args):
"rms_norm.py",
"rope.py",
"swiglu.py",
"attention.py",
]:
result = subprocess.run(
f"python {test} {args}", text=True, encoding="utf-8", shell=True
......
#ifndef ATTENTION_H
#define ATTENTION_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::attention::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
        static infiniStatus_t create(                    \
            infiniopHandle_t handle,                     \
            Descriptor **desc_ptr,                       \
            infiniopTensorDescriptor_t out_desc,         \
            infiniopTensorDescriptor_t q_desc,           \
            infiniopTensorDescriptor_t k_desc,           \
            infiniopTensorDescriptor_t v_desc,           \
            infiniopTensorDescriptor_t k_cache_desc,     \
            infiniopTensorDescriptor_t v_cache_desc,     \
            uint64_t pos);                               \
}; \
}
#endif // ATTENTION_H
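For reference, a backend would instantiate this macro once per device namespace. A minimal sketch (the `cpu` namespace and header path are assumptions for illustration, not part of this commit):

// e.g. in a hypothetical ops/attention/cpu/attention_cpu.h
#include "../attention.h"
DESCRIPTOR(cpu) // declares op::attention::cpu::Descriptor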
#include "../../operator.h"
#include "../../../utils.h"
#include "../../../utils/check.h"
#include "../../handle.h"
#include "../../tensor.h"
#include "infiniop/ops/attention.h"
#include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/gemm.h"
#include "infiniop/ops/rearrange.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
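// This operator is composed from existing primitives rather than one fused kernel:
//   1. rearrange k and v into the kv caches at position `pos`
//   2. gemm1: qk = qk_alpha * q @ k_cache^T (queries grouped per kv head)
//   3. causal softmax over qk, computed in place
//   4. gemm2: temp_out = softmax(qk) @ v_cache
//   5. rearrange temp_out into the caller's out layout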
struct InfiniopAttentionDescriptor {
InfiniopDescriptor _super;
infiniopRearrangeDescriptor_t rearrange_desc_k;
infiniopRearrangeDescriptor_t rearrange_desc_v;
infiniopRearrangeDescriptor_t rearrange_desc_q;
infiniopRearrangeDescriptor_t rearrange_desc_out;
infiniopGemmDescriptor_t matmul_desc1;
infiniopGemmDescriptor_t matmul_desc2;
infiniopCausalSoftmaxDescriptor_t softmax_desc;
uint64_t workspace_size;
uint64_t rearranged_q_size;
uint64_t matmul1_workspace_size;
uint64_t matmul1_tensor_size;
uint64_t matmul2_workspace_size;
uint64_t matmul2_tensor_size;
uint64_t softmax_workspace_size;
uint64_t k_cache_offset;
uint64_t v_cache_offset;
float qk_alpha;
};
__C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t handle,
infiniopAttentionDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_desc,
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
uint64_t pos) {
if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
if (!out_desc->isContiguous(0, 2)) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
if (q_desc->strides()[2] != 1 || k_desc->strides()[2] != 1 || v_desc->strides()[2] != 1 || k_cache_desc->strides()[2] != 1 || v_cache_desc->strides()[2] != 1) {
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
uint64_t n_q_head = q_desc->shape()[0];
uint64_t seq_len = q_desc->shape()[1];
uint64_t head_dim = q_desc->shape()[2];
uint64_t n_kv_head = k_desc->shape()[0];
uint64_t total_seq_len = seq_len + pos;
    if (n_kv_head == 0 || n_q_head % n_kv_head != 0) {
        return INFINI_STATUS_BAD_PARAM;
    }
    uint64_t n_group = n_q_head / n_kv_head;
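    // e.g. (illustrative numbers) n_q_head = 32, n_kv_head = 4, seq_len = 5, head_dim = 64,
    // pos = 3 give total_seq_len = 8 and n_group = 8: eight query heads share each kv head.
    // out: [seq_len, n_q_head, head_dim]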
if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// k: [n_kv_head, seq_len, head_dim]
if (k_desc->shape()[0] != n_kv_head || k_desc->shape()[1] != seq_len || k_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// v: [n_kv_head, seq_len, head_dim]
if (v_desc->shape()[0] != n_kv_head || v_desc->shape()[1] != seq_len || v_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// k_cache: [n_kv_head, _, head_dim]
if (k_cache_desc->shape()[0] != n_kv_head || k_cache_desc->shape()[1] < total_seq_len || k_cache_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// v_cache: [n_kv_head, _, head_dim]
if (v_cache_desc->shape()[0] != n_kv_head || v_cache_desc->shape()[1] < total_seq_len || v_cache_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
}
// Rearrange k into k_cache
infiniopTensorDescriptor_t dst_k_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_k_desc, 3, k_desc->shape().data(), k_cache_desc->strides().data(), k_cache_desc->dtype()));
infiniopRearrangeDescriptor_t rearrange_desc_k;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_k, dst_k_desc, k_desc));
// Rearrange v into v_cache
infiniopTensorDescriptor_t dst_v_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&dst_v_desc, 3, v_desc->shape().data(), v_cache_desc->strides().data(), v_cache_desc->dtype()));
infiniopRearrangeDescriptor_t rearrange_desc_v;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc));
infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
uint64_t rearranged_q_size = 0;
infiniopTensorDescriptor_t rearranged_q_desc;
// Rearrange q into contiguous
if (!q_desc->isContiguous(0, 1)) {
CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
rearranged_q_size = rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype());
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
}
// Matmul1: q * full_k
    // q: [n_q_head, seq_len, head_dim] -> [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t reshaped_q_desc;
CHECK_STATUS(infiniopCreateTensorDescriptor(&reshaped_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));
// full_k: [n_kv_head, head_dim, total_seq_len]
infiniopTensorDescriptor_t full_k_desc;
uint64_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype()));
TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));
// qk: [n_kv_head, n_group * seq_len, total_seq_len]
infiniopTensorDescriptor_t qk_desc;
uint64_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));
    // qk_alpha: scale q @ k^T by 1 / sqrt(head_dim)
    float qk_alpha = 1.f / std::sqrt(float(head_dim));
    // matmul1_desc
infiniopGemmDescriptor_t matmul1_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
// matmul1 workspace size
uint64_t matmul1_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
// matmul1 tensor size
uint64_t matmul1_tensor_size = qk_desc->numel() * infiniSizeOf(qk_desc->dtype());
// CausalSoftmax: softmax(qk)
// qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(0, 1));
infiniopCausalSoftmaxDescriptor_t softmax_desc;
CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc));
// softmax workspace size
uint64_t softmax_workspace_size;
CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
// Matmul2: softmax(qk) * full_v
// softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
// full_v: [n_kv_head, total_seq_len, head_dim]
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
infiniopTensorDescriptor_t full_v_desc;
uint64_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype()));
// temp_out: [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t temp_out_desc;
uint64_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&temp_out_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
// matmul2_desc
infiniopGemmDescriptor_t matmul2_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, temp_out_desc, qk_desc, full_v_desc));
// matmul2 workspace size
uint64_t matmul2_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
// matmul2 tensor size
uint64_t matmul2_tensor_size = temp_out_desc->numel() * infiniSizeOf(temp_out_desc->dtype());
// Rearrange temp_out into out
// out: [seq_len, n_q_head, head_dim]
// temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
TRANSFORM_TENSOR_DESC(temp_out_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimMerge(0, 1));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimPermute({1, 0, 2}));
infiniopRearrangeDescriptor_t rearrange_desc_out;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, temp_out_desc));
// workspace size
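    // Layout: [rearranged q (optional) | qk buffer | per-stage scratch]. The qk buffer
    // (matmul1_tensor_size) stays live across gemm1, softmax, and gemm2, so the total is
    // its size plus the largest per-stage scratch requirement.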
uint64_t workspace_size = rearranged_q_size + std::max(std::max(matmul1_workspace_size + matmul1_tensor_size, matmul1_tensor_size + softmax_workspace_size), matmul1_tensor_size + matmul2_workspace_size + matmul2_tensor_size);
// k_cache_offset
uint64_t k_cache_offset = 0;
if (pos > 0) {
k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
}
// v_cache_offset
uint64_t v_cache_offset = 0;
if (pos > 0) {
v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
}
// create attention descriptor
*(InfiniopAttentionDescriptor **)desc_ptr = new InfiniopAttentionDescriptor{
{handle->device, handle->device_id},
rearrange_desc_k,
rearrange_desc_v,
rearrange_desc_q,
rearrange_desc_out,
matmul1_desc,
matmul2_desc,
softmax_desc,
workspace_size,
rearranged_q_size,
matmul1_workspace_size,
matmul1_tensor_size,
matmul2_workspace_size,
matmul2_tensor_size,
softmax_workspace_size,
k_cache_offset,
v_cache_offset,
        qk_alpha,
};
return INFINI_STATUS_SUCCESS;
}
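A minimal driver-side sketch of the intended call order (the handle, descriptors, buffers, and stream here are elided and hypothetical):

// infiniopAttentionDescriptor_t desc;
// CHECK_STATUS(infiniopCreateAttentionDescriptor(handle, &desc, out_desc, q_desc, k_desc,
//                                                v_desc, k_cache_desc, v_cache_desc, pos));
// uint64_t ws_size;
// CHECK_STATUS(infiniopGetAttentionWorkspaceSize(desc, &ws_size));
// void *ws = ...; // allocate ws_size bytes on the target device
// CHECK_STATUS(infiniopAttention(desc, ws, ws_size, out, q, k, v, k_cache, v_cache, stream));
// CHECK_STATUS(infiniopDestroyAttentionDescriptor(desc));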
__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, uint64_t *size) {
*size = ((InfiniopAttentionDescriptor *)desc)->workspace_size;
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc_,
void *workspace,
uint64_t workspace_size,
void *out,
void const *q,
void const *k,
void const *v,
void *k_cache,
void *v_cache,
void *stream) {
auto desc = (InfiniopAttentionDescriptor *)desc_;
void *workspace_ = workspace;
if (workspace_size < desc->workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
}
// concat k and v to k_cache and v_cache
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k,
(char *)k_cache + desc->k_cache_offset, k, stream));
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_v,
(char *)v_cache + desc->v_cache_offset, v, stream));
// rearrange q into contiguous
void const *_q = q;
if (desc->rearrange_desc_q) {
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, (char *)workspace_, q, stream));
_q = workspace_;
workspace_ = (char *)workspace_ + desc->rearranged_q_size;
}
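    // From here on workspace_ points at the qk buffer; each op's scratch starts at
    // workspace_ + matmul1_tensor_size (gemm2's scratch additionally follows its
    // temp_out buffer).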
// matmul1: q * full_k
CHECK_STATUS(infiniopGemm(desc->matmul_desc1,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size,
workspace_, _q, k_cache, desc->qk_alpha, 0.0, stream));
// softmax(qk)
CHECK_STATUS(infiniopCausalSoftmax(desc->softmax_desc,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size,
workspace_, workspace_, stream));
// matmul2: softmax(qk) * full_v
CHECK_STATUS(infiniopGemm(desc->matmul_desc2,
(char *)workspace_ + desc->matmul1_tensor_size + desc->matmul2_tensor_size,
workspace_size - desc->matmul1_tensor_size - desc->matmul2_tensor_size,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_, v_cache, 1.0, 0.0, stream));
// rearrange out
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, (char *)workspace_ + desc->matmul1_tensor_size, stream));
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopDestroyAttentionDescriptor(infiniopAttentionDescriptor_t desc_) {
auto desc = (InfiniopAttentionDescriptor *)desc_;
if (desc->rearrange_desc_q) {
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_q));
}
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_k));
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_v));
CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->rearrange_desc_out));
CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc1));
CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul_desc2));
CHECK_STATUS(infiniopDestroyCausalSoftmaxDescriptor(desc->softmax_desc));
delete desc;
return INFINI_STATUS_SUCCESS;
}
......@@ -48,7 +48,7 @@ infiniStatus_t causal_softmax(const CausalSoftmaxInfo *info, T *y, const T *x) {
if constexpr (std::is_same<T, fp16_t>::value) {
y_[j * info->y_stride_j] = utils::cast<fp16_t>(utils::cast<float>(y_[j * info->y_stride_j]) / sum);
} else {
-            y_[j * info->y_stride_j] = y_[y_offset + j * info->y_stride_j] / sum;
+            y_[j * info->y_stride_j] = y_[j * info->y_stride_j] / sum;
}
}
}
......
......@@ -18,7 +18,7 @@ __device__ __forceinline__ Tcompute sumSquared(const Tdata *data_ptr, size_t cou
// Each thread computes its partial sum
for (size_t i = threadIdx.x; i < count; i += BLOCK_SIZE) {
-        ss += Tcompute(data_ptr[i] * data_ptr[i]);
+        ss += Tcompute(data_ptr[i]) * Tcompute(data_ptr[i]);
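The cast order matters: squaring in the narrow `Tdata` type (e.g. fp16, whose maximum is 65504) can overflow before the widening conversion, so each operand is cast to `Tcompute` first.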
}
// Use CUB block-level reduction
......
......@@ -2,9 +2,19 @@
#define __INFINIOP_TENSOR_H__
#include "infiniop/tensor_descriptor.h"
#include "../utils.h"
#include <string>
#include <vector>
#define TRANSFORM_TENSOR_DESC(__TENSOR_DESC__, __OP__) \
do { \
auto __RESULT__ = __TENSOR_DESC__->__OP__; \
CHECK_RESULT(__RESULT__); \
__TENSOR_DESC__ = __RESULT__.take(); \
} while (0)
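// Usage sketch: on success the macro rebinds the descriptor variable to the
// transformed view; on failure it propagates the error status, e.g.
//   TRANSFORM_TENSOR_DESC(desc, dimPermute({0, 2, 1})); // desc is now a transposed view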
struct InfiniopTensorDescriptor {
private:
// Datatype
......@@ -32,9 +42,9 @@ public:
bool hasBroadcastDim() const;
std::vector<size_t> getBroadcastDim() const;
-    infiniopTensorDescriptor_t dimMerge(size_t dim_start, size_t dim_end) const;
-    infiniopTensorDescriptor_t dimSplit(size_t axis, const std::vector<size_t> &dims) const;
-    infiniopTensorDescriptor_t dimPermute(const std::vector<size_t> &order) const;
+    utils::Result<infiniopTensorDescriptor_t> dimMerge(size_t dim_start, size_t dim_end) const;
+    utils::Result<infiniopTensorDescriptor_t> dimSplit(size_t axis, const std::vector<size_t> &dims) const;
+    utils::Result<infiniopTensorDescriptor_t> dimPermute(const std::vector<size_t> &order) const;
std::string toString() const;
};
......
......@@ -12,7 +12,7 @@ __C __export infiniStatus_t infiniopCreateTensorDescriptor(infiniopTensorDescrip
std::vector<ptrdiff_t> strides(ndim);
ptrdiff_t dsize = 1;
if (ndim > 0) {
-        for (size_t i = ndim - 1; i >= 0; i--) {
+        for (int i = (int)ndim - 1; i >= 0; i--) {
strides[i] = dsize;
dsize *= shape_[i];
}
......@@ -104,10 +104,8 @@ std::vector<size_t> InfiniopTensorDescriptor::getBroadcastDim() const {
return res;
}
-infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimMerge(size_t dim_start, size_t dim_end) const {
-    if (dim_start > dim_end || dim_end >= ndim()) {
-        return nullptr;
-    }
+utils::Result<infiniopTensorDescriptor_t> InfiniopTensorDescriptor::dimMerge(size_t dim_start, size_t dim_end) const {
+    CHECK_OR_RETURN(dim_start <= dim_end && dim_end < ndim(), INFINI_STATUS_BAD_PARAM);
size_t new_ndim = ndim() - (dim_end - dim_start);
std::vector<size_t> new_shape(new_ndim);
......@@ -120,9 +118,7 @@ infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimMerge(size_t dim_start,
index++;
}
-    if (!isContiguous(dim_start, dim_end)) {
-        return nullptr;
-    }
+    CHECK_OR_RETURN(isContiguous(dim_start, dim_end), INFINI_STATUS_BAD_PARAM);
new_shape[index] = 1;
for (size_t i = dim_start; i <= dim_end; i++) {
......@@ -138,15 +134,15 @@ infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimMerge(size_t dim_start,
index++;
}
-    return new InfiniopTensorDescriptor(_dtype, new_ndim, new_shape.data(), new_strides.data());
+    return utils::Result<infiniopTensorDescriptor_t>(
+        new InfiniopTensorDescriptor(_dtype, new_ndim, new_shape.data(), new_strides.data()));
}
-infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimSplit(size_t axis, const std::vector<size_t> &dims) const {
+utils::Result<infiniopTensorDescriptor_t> InfiniopTensorDescriptor::dimSplit(size_t axis, const std::vector<size_t> &dims) const {
size_t ndim_ = ndim();
-    if (dim(axis) != std::accumulate(dims.begin(), dims.end(), (size_t)1, std::multiplies<size_t>())) {
-        return nullptr;
-    }
+    CHECK_OR_RETURN(dim(axis) == std::accumulate(dims.begin(), dims.end(), (size_t)1, std::multiplies<size_t>()),
+                    INFINI_STATUS_BAD_PARAM);
size_t new_ndim = ndim_ + dims.size() - 1;
std::vector<size_t> new_shape(new_ndim);
......@@ -168,24 +164,22 @@ infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimSplit(size_t axis, const
index++;
}
-    return new InfiniopTensorDescriptor(_dtype, new_ndim, new_shape.data(), new_strides.data());
+    return utils::Result<infiniopTensorDescriptor_t>(
+        new InfiniopTensorDescriptor(_dtype, new_ndim, new_shape.data(), new_strides.data()));
}
-infiniopTensorDescriptor_t InfiniopTensorDescriptor::dimPermute(const std::vector<size_t> &order) const {
+utils::Result<infiniopTensorDescriptor_t> InfiniopTensorDescriptor::dimPermute(const std::vector<size_t> &order) const {
auto ndim_ = ndim();
-    if (order.size() != ndim_) {
-        return nullptr;
-    }
+    CHECK_OR_RETURN(order.size() == ndim_, INFINI_STATUS_BAD_PARAM);
std::vector<size_t> new_shape(ndim_);
std::vector<ptrdiff_t> new_strides(ndim_);
for (size_t i = 0; i < ndim_; i++) {
-        if (std::find(order.begin(), order.end(), i) == order.end()) {
-            return nullptr;
-        }
+        CHECK_OR_RETURN(std::find(order.begin(), order.end(), i) != order.end(), INFINI_STATUS_BAD_PARAM);
new_shape[i] = dim(order[i]);
new_strides[i] = stride(order[i]);
}
-    return new InfiniopTensorDescriptor(_dtype, ndim_, new_shape.data(), new_strides.data());
+    return utils::Result<infiniopTensorDescriptor_t>(
+        new InfiniopTensorDescriptor(_dtype, ndim_, new_shape.data(), new_strides.data()));
}
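A short sketch of how callers consume the Result-returning transforms (the shapes are illustrative):

// auto split = desc->dimSplit(0, {2, 3});      // [6, 4] -> [2, 3, 4]
// CHECK_RESULT(split);
// auto merged = split.take()->dimMerge(0, 1);  // back to [6, 4] (contiguous dims only)
// CHECK_RESULT(merged);
// infiniopTensorDescriptor_t out = merged.take();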
std::string InfiniopTensorDescriptor::toString() const {
......
-from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float, c_bool
+from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p
import ctypes
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
-from operatorspy import (
+from libinfiniop import (
open_lib,
to_tensor,
CTensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
rearrange_tensor,
create_workspace,
get_args,
get_test_devices,
test_operator,
debug,
get_tolerance,
profile_operation,
)
-from operatorspy.tests.test_utils import get_args
import torch
import torch.nn.functional as F
class AttentionDescriptor(Structure):
......@@ -95,13 +95,13 @@ def test(
pos,
k_cache_buf_len,
v_cache_buf_len,
-    dtype=torch.float16,
    q_stride=None,
    k_stride=None,
    v_stride=None,
    k_cache_stride=None,
    v_cache_stride=None,
-    sync=None
+    dtype=torch.float16,
+    sync=None,
):
print(
f"Testing Attention on {torch_device} with n_q_head:{n_q_head} n_kv_head:{n_kv_head} seq_len:{seq_len} head_dim:{head_dim} pos:{pos} "
......@@ -160,12 +160,15 @@ def test(
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
-    out_tensor.descriptor.contents.invalidate()
-    q_tensor.descriptor.contents.invalidate()
-    k_tensor.descriptor.contents.invalidate()
-    v_tensor.descriptor.contents.invalidate()
-    k_cache_tensor.descriptor.contents.invalidate()
-    v_cache_tensor.descriptor.contents.invalidate()
+    for tensor in [
+        out_tensor,
+        q_tensor,
+        k_tensor,
+        v_tensor,
+        k_cache_tensor,
+        v_cache_tensor,
+    ]:
+        tensor.destroyDesc(lib)
workspace_size = c_uint64(0)
check_error(
......@@ -173,6 +176,7 @@ def test(
)
workspace = create_workspace(workspace_size.value, out.device)
+    def lib_attention():
check_error(
lib.infiniopAttention(
descriptor,
......@@ -188,137 +192,36 @@ def test(
)
)
-    assert torch.allclose(out, ans, atol=1e-4, rtol=1e-2)
-    check_error(lib.infiniopDestroyAttentionDescriptor(descriptor))
-def test_cpu(lib, test_cases):
-    device = DeviceEnum.DEVICE_CPU
-    handle = create_handle(lib, device)
-    for (
-        n_q_head,
-        n_kv_head,
-        seq_len,
-        head_dim,
-        pos,
-        k_cache_buf_len,
-        v_cache_buf_len,
-        dtype,
-        q_stride,
-        k_stride,
-        v_stride,
-        k_cache_stride,
-        v_cache_stride,
-    ) in test_cases:
-        test(
-            lib,
-            handle,
-            "cpu",
-            n_q_head,
-            n_kv_head,
-            seq_len,
-            head_dim,
-            pos,
-            k_cache_buf_len,
-            v_cache_buf_len,
-            dtype,
-            q_stride,
-            k_stride,
-            v_stride,
-            k_cache_stride,
-            v_cache_stride,
-        )
-    destroy_handle(lib, handle)
-def test_cuda(lib, test_cases):
-    device = DeviceEnum.DEVICE_CUDA
-    handle = create_handle(lib, device)
-    for (
-        n_q_head,
-        n_kv_head,
-        seq_len,
-        head_dim,
-        pos,
-        k_cache_buf_len,
-        v_cache_buf_len,
-        dtype,
-        q_stride,
-        k_stride,
-        v_stride,
-        k_cache_stride,
-        v_cache_stride,
-    ) in test_cases:
-        test(
-            lib,
-            handle,
-            "cuda",
-            n_q_head,
-            n_kv_head,
-            seq_len,
-            head_dim,
-            pos,
-            k_cache_buf_len,
-            v_cache_buf_len,
-            dtype,
-            q_stride,
-            k_stride,
-            v_stride,
-            k_cache_stride,
-            v_cache_stride,
-        )
-    destroy_handle(lib, handle)
+    lib_attention()
-def test_bang(lib, test_cases):
-    import torch_mlu
-    device = DeviceEnum.DEVICE_BANG
-    handle = create_handle(lib, device)
-    for (
-        n_q_head,
-        n_kv_head,
-        seq_len,
-        head_dim,
-        pos,
-        k_cache_buf_len,
-        v_cache_buf_len,
-        dtype,
-        q_stride,
-        k_stride,
-        v_stride,
-        k_cache_stride,
-        v_cache_stride,
-    ) in test_cases:
-        test(
-            lib,
-            handle,
-            "mlu",
-            n_q_head,
-            n_kv_head,
-            seq_len,
-            head_dim,
-            pos,
-            k_cache_buf_len,
-            v_cache_buf_len,
-            dtype,
-            q_stride,
-            k_stride,
-            v_stride,
-            k_cache_stride,
-            v_cache_stride,
-        )
+    # Validate results
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+    if DEBUG:
+        debug(out, ans, atol=atol, rtol=rtol)
+    assert torch.allclose(out, ans, atol=atol, rtol=rtol)
-    destroy_handle(lib, handle)
+    # Profiling workflow
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: attention(q, k, v, k_cache, v_cache, pos), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation(" lib", lambda: lib_attention(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+    check_error(lib.infiniopDestroyAttentionDescriptor(descriptor))
if __name__ == "__main__":
+    _TENSOR_DTYPES = [torch.float16, torch.float32]
+    # Tolerance map for different data types
+    _TOLERANCE_MAP = {
+        torch.float16: {"atol": 1e-4, "rtol": 1e-2},
+        torch.float32: {"atol": 1e-6, "rtol": 1e-4},
+    }
+    DEBUG = False
+    PROFILE = False
+    NUM_PRERUN = 10
+    NUM_ITERATIONS = 1000
test_cases = [
# prefill
(
......@@ -329,7 +232,6 @@ if __name__ == "__main__":
0, # pos
2048, # k_cache_buf_len
2048, # v_cache_buf_len
-            torch.float16, # dtype
[64, 2560, 1], # q_stride
[64, 2560, 1], # k_stride
[64, 2560, 1], # v_stride
......@@ -345,7 +247,6 @@ if __name__ == "__main__":
3, # pos
2048, # k_cache_buf_len
2048, # v_cache_buf_len
-            torch.float16, # dtype
[64, 2560, 1], # q_stride
[64, 2560, 1], # k_stride
[64, 2560, 1], # v_stride
......@@ -361,7 +262,6 @@ if __name__ == "__main__":
1, # pos
8, # k_cache_buf_len
8, # v_cache_buf_len
-            torch.float16, # dtype
None, # q_stride
None, # k_stride
None, # v_stride
......@@ -410,12 +310,13 @@ if __name__ == "__main__":
infiniopAttentionDescriptor_t,
]
-    if args.cpu:
-        test_cpu(lib, test_cases)
-    if args.cuda:
-        test_cuda(lib, test_cases)
-    if args.bang:
-        test_bang(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang):
-        test_cpu(lib, test_cases)
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(lib, device, test, test_cases, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")