Unverified Commit 5beab8c0 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #229 from InfiniTensor/issue/191/fix

issue/191/fix 为attention增加对齐,修复cuda causal softmax
parents 2e624a8e b79f2607
......@@ -20,15 +20,14 @@ struct InfiniopAttentionDescriptor {
infiniopGemmDescriptor_t matmul_desc1;
infiniopGemmDescriptor_t matmul_desc2;
infiniopCausalSoftmaxDescriptor_t softmax_desc;
uint64_t workspace_size;
uint64_t rearranged_q_size;
uint64_t matmul1_workspace_size;
uint64_t matmul1_tensor_size;
uint64_t matmul2_workspace_size;
uint64_t matmul2_tensor_size;
uint64_t softmax_workspace_size;
uint64_t k_cache_offset;
uint64_t v_cache_offset;
size_t workspace_size;
size_t op_workspace_offset;
size_t op_workspace_size;
size_t q_cont_offset;
size_t att_score_offset;
size_t att_val_offset;
size_t k_cache_offset;
size_t v_cache_offset;
float qk_alpha;
};
......@@ -40,7 +39,7 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
uint64_t pos) {
size_t pos) {
if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}
......@@ -53,13 +52,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return INFINI_STATUS_BAD_TENSOR_STRIDES;
}
uint64_t n_q_head = q_desc->shape()[0];
uint64_t seq_len = q_desc->shape()[1];
uint64_t head_dim = q_desc->shape()[2];
uint64_t hidden_size = n_q_head * head_dim;
uint64_t n_kv_head = k_desc->shape()[0];
uint64_t total_seq_len = seq_len + pos;
uint64_t n_group = n_q_head / n_kv_head;
size_t n_q_head = q_desc->shape()[0];
size_t seq_len = q_desc->shape()[1];
size_t head_dim = q_desc->shape()[2];
size_t hidden_size = n_q_head * head_dim;
size_t n_kv_head = k_desc->shape()[0];
size_t total_seq_len = seq_len + pos;
size_t n_group = n_q_head / n_kv_head;
size_t alignment = 256;
if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM;
......@@ -98,12 +98,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc));
infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
uint64_t rearranged_q_size = 0;
size_t q_cont_size = 0;
infiniopTensorDescriptor_t rearranged_q_desc;
// Rearrange q into contiguous
if (!q_desc->isContiguous(0, 1)) {
CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
rearranged_q_size = rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype());
q_cont_size = utils::align(rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype()), alignment);
rearrange_desc_q = new InfiniopDescriptor;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
}
......@@ -116,12 +116,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));
// full_k: [n_kv_head, head_dim, total_seq_len]
infiniopTensorDescriptor_t full_k_desc;
uint64_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
size_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype()));
TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));
// qk: [n_kv_head, n_group * seq_len, total_seq_len]
infiniopTensorDescriptor_t qk_desc;
uint64_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
size_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));
// matmul1_desc
// qk_alpha
......@@ -129,10 +129,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopGemmDescriptor_t matmul1_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
// matmul1 workspace size
uint64_t matmul1_workspace_size;
size_t matmul1_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
// matmul1 tensor size
uint64_t matmul1_tensor_size = qk_desc->numel() * infiniSizeOf(qk_desc->dtype());
matmul1_workspace_size = utils::align(matmul1_workspace_size, alignment);
// attention score tensor size
size_t attn_score_size = utils::align(qk_desc->numel() * infiniSizeOf(qk_desc->dtype()), alignment);
// CausalSoftmax: softmax(qk)
// qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
......@@ -141,8 +142,9 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopCausalSoftmaxDescriptor_t softmax_desc;
CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc));
// softmax workspace size
uint64_t softmax_workspace_size;
size_t softmax_workspace_size;
CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
softmax_workspace_size = utils::align(softmax_workspace_size, alignment);
// Matmul2: softmax(qk) * full_v
// softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
......@@ -150,41 +152,44 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
infiniopTensorDescriptor_t full_v_desc;
uint64_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
size_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype()));
// temp_out: [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t temp_out_desc;
uint64_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&temp_out_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
infiniopTensorDescriptor_t att_val_desc;
size_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&att_val_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
// matmul2_desc
infiniopGemmDescriptor_t matmul2_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, temp_out_desc, qk_desc, full_v_desc));
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, att_val_desc, qk_desc, full_v_desc));
// matmul2 workspace size
uint64_t matmul2_workspace_size;
size_t matmul2_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
// matmul2 tensor size
uint64_t matmul2_tensor_size = temp_out_desc->numel() * infiniSizeOf(temp_out_desc->dtype());
matmul2_workspace_size = utils::align(matmul2_workspace_size, alignment);
// attention value tensor size
size_t att_val_size = utils::align(att_val_desc->numel() * infiniSizeOf(att_val_desc->dtype()), alignment);
// Rearrange temp_out into out
// out: [seq_len, n_q_head, head_dim]
// temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
TRANSFORM_TENSOR_DESC(temp_out_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimMerge(0, 1));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimPermute({1, 0, 2}));
TRANSFORM_TENSOR_DESC(att_val_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(att_val_desc, dimMerge(0, 1));
TRANSFORM_TENSOR_DESC(att_val_desc, dimPermute({1, 0, 2}));
infiniopRearrangeDescriptor_t rearrange_desc_out;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, temp_out_desc));
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, att_val_desc));
// workspace size
uint64_t workspace_size = rearranged_q_size + std::max(std::max(matmul1_workspace_size + matmul1_tensor_size, matmul1_tensor_size + softmax_workspace_size), matmul1_tensor_size + matmul2_workspace_size + matmul2_tensor_size);
size_t op_workspace_size = utils::align(std::max(std::max(matmul1_workspace_size, matmul2_workspace_size), softmax_workspace_size), alignment);
size_t temp_tensors_size = attn_score_size + std::max(q_cont_size, att_val_size);
size_t workspace_size = temp_tensors_size + op_workspace_size;
// k_cache_offset
uint64_t k_cache_offset = 0;
size_t k_cache_offset = 0;
if (pos > 0) {
k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
}
// v_cache_offset
uint64_t v_cache_offset = 0;
size_t v_cache_offset = 0;
if (pos > 0) {
v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
}
......@@ -200,12 +205,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
matmul2_desc,
softmax_desc,
workspace_size,
rearranged_q_size,
matmul1_workspace_size,
matmul1_tensor_size,
matmul2_workspace_size,
matmul2_tensor_size,
softmax_workspace_size,
temp_tensors_size,
op_workspace_size,
attn_score_size,
0,
attn_score_size,
k_cache_offset,
v_cache_offset,
1.f / std::sqrt(float(head_dim)),
......@@ -214,14 +218,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, uint64_t *size) {
/// Report the total byte size of the device workspace that a caller must
/// provide to infiniopAttention for this descriptor.
__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, size_t *size) {
    auto const attention = reinterpret_cast<InfiniopAttentionDescriptor *>(desc);
    *size = attention->workspace_size;
    return INFINI_STATUS_SUCCESS;
}
__C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc_,
void *workspace,
uint64_t workspace_size,
void *workspace_,
size_t workspace_size_,
void *out,
void const *q,
void const *k,
......@@ -230,11 +234,14 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc
void *v_cache,
void *stream) {
auto desc = (InfiniopAttentionDescriptor *)desc_;
void *workspace_ = workspace;
if (workspace_size < desc->workspace_size) {
if (workspace_size_ < desc->workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; // STATUS_MEMORY_NOT_ALLOCATED
}
void *workspace = (char *)workspace_ + desc->op_workspace_offset;
size_t workspace_size = desc->op_workspace_size;
void *att_score = (char *)workspace_ + desc->att_score_offset;
void *att_val = (char *)workspace_ + desc->att_val_offset;
void const *q_ = q;
// concat k and v to k_cache and v_cache
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k,
(char *)k_cache + desc->k_cache_offset, k, stream));
......@@ -243,28 +250,26 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc
(char *)v_cache + desc->v_cache_offset, v, stream));
// rearrange q into contiguous
void const *_q = q;
if (desc->rearrange_desc_q) {
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, (char *)workspace_, q, stream));
_q = workspace_;
workspace_ = (char *)workspace_ + desc->rearranged_q_size;
void *q_cont = (char *)workspace_ + desc->q_cont_offset;
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, q_cont, q, stream));
q_ = q_cont;
}
// matmul1: q * full_k
CHECK_STATUS(infiniopGemm(desc->matmul_desc1,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size,
workspace_, _q, k_cache, desc->qk_alpha, 0.0, stream));
workspace, workspace_size,
att_score, q_, k_cache, desc->qk_alpha, 0.0, stream));
// softmax(qk)
CHECK_STATUS(infiniopCausalSoftmax(desc->softmax_desc,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size,
workspace_, workspace_, stream));
workspace, workspace_size,
att_score, att_score, stream));
// matmul2: softmax(qk) * full_v
CHECK_STATUS(infiniopGemm(desc->matmul_desc2,
(char *)workspace_ + desc->matmul1_tensor_size + desc->matmul2_tensor_size,
workspace_size - desc->matmul1_tensor_size - desc->matmul2_tensor_size,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_, v_cache, 1.0, 0.0, stream));
workspace, workspace_size,
att_val, att_score, v_cache, 1.0, 0.0, stream));
// rearrange out
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, (char *)workspace_ + desc->matmul1_tensor_size, stream));
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, att_val, stream));
return INFINI_STATUS_SUCCESS;
}
......
......@@ -18,7 +18,7 @@ INFINIOP_CUDA_KERNEL causalSoftmax(
// [Reduce] Find max value in each row and store in shared memory
__shared__ Tdata max_;
Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x, width);
Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x, width - height + 1 + blockIdx.x);
if (threadIdx.x == 0) {
max_ = max_0;
}
......
......@@ -100,4 +100,12 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
namespace utils {
/// @brief Round `size` up to the nearest multiple of `alignment`.
/// @param size      Byte count to round up.
/// @param alignment Alignment in bytes; must be non-zero. Unlike the
///                  bit-mask idiom, this works for ANY non-zero alignment,
///                  not only powers of two.
/// @return Smallest multiple of `alignment` that is >= `size`.
inline size_t align(size_t size, size_t alignment) {
    // Division-based round-up: identical to the former
    // `(size + alignment - 1) & ~(alignment - 1)` for power-of-two
    // alignments, but also correct for arbitrary alignments.
    return (size + alignment - 1) / alignment * alignment;
}
} // namespace utils
#endif
......@@ -215,7 +215,7 @@ if __name__ == "__main__":
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {"atol": 1e-4, "rtol": 1e-2},
torch.float32: {"atol": 1e-6, "rtol": 1e-4},
torch.float32: {"atol": 1e-5, "rtol": 1e-3},
}
DEBUG = False
......@@ -268,6 +268,20 @@ if __name__ == "__main__":
None, # k_cache_stride
None, # v_cache_stride
),
(
28, # n_q_head
28, # n_kv_head
15, # seq_len
128, # head_dim
0, # pos
2048, # k_cache_buf_len
2048, # v_cache_buf_len
[128, 10752, 1], # q_stride
[128, 10752, 1], # k_stride
[128, 10752, 1], # v_stride
[128, 3584, 1], # k_cache_stride
[128, 3584, 1], # v_cache_stride
),
]
args = get_args()
lib = open_lib()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment