Unverified Commit 5beab8c0 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #229 from InfiniTensor/issue/191/fix

issue/191/fix 为attention增加对齐,修复cuda causal softmax
parents 2e624a8e b79f2607
...@@ -20,15 +20,14 @@ struct InfiniopAttentionDescriptor { ...@@ -20,15 +20,14 @@ struct InfiniopAttentionDescriptor {
infiniopGemmDescriptor_t matmul_desc1; infiniopGemmDescriptor_t matmul_desc1;
infiniopGemmDescriptor_t matmul_desc2; infiniopGemmDescriptor_t matmul_desc2;
infiniopCausalSoftmaxDescriptor_t softmax_desc; infiniopCausalSoftmaxDescriptor_t softmax_desc;
uint64_t workspace_size; size_t workspace_size;
uint64_t rearranged_q_size; size_t op_workspace_offset;
uint64_t matmul1_workspace_size; size_t op_workspace_size;
uint64_t matmul1_tensor_size; size_t q_cont_offset;
uint64_t matmul2_workspace_size; size_t att_score_offset;
uint64_t matmul2_tensor_size; size_t att_val_offset;
uint64_t softmax_workspace_size; size_t k_cache_offset;
uint64_t k_cache_offset; size_t v_cache_offset;
uint64_t v_cache_offset;
float qk_alpha; float qk_alpha;
}; };
...@@ -40,7 +39,7 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -40,7 +39,7 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopTensorDescriptor_t v_desc, infiniopTensorDescriptor_t v_desc,
infiniopTensorDescriptor_t k_cache_desc, infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc, infiniopTensorDescriptor_t v_cache_desc,
uint64_t pos) { size_t pos) {
if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) { if (out_desc->ndim() != 3 || q_desc->ndim() != 3 || k_desc->ndim() != 3 || v_desc->ndim() != 3 || k_cache_desc->ndim() != 3 || v_cache_desc->ndim() != 3) {
return INFINI_STATUS_BAD_TENSOR_SHAPE; return INFINI_STATUS_BAD_TENSOR_SHAPE;
} }
...@@ -53,13 +52,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -53,13 +52,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return INFINI_STATUS_BAD_TENSOR_STRIDES; return INFINI_STATUS_BAD_TENSOR_STRIDES;
} }
uint64_t n_q_head = q_desc->shape()[0]; size_t n_q_head = q_desc->shape()[0];
uint64_t seq_len = q_desc->shape()[1]; size_t seq_len = q_desc->shape()[1];
uint64_t head_dim = q_desc->shape()[2]; size_t head_dim = q_desc->shape()[2];
uint64_t hidden_size = n_q_head * head_dim; size_t hidden_size = n_q_head * head_dim;
uint64_t n_kv_head = k_desc->shape()[0]; size_t n_kv_head = k_desc->shape()[0];
uint64_t total_seq_len = seq_len + pos; size_t total_seq_len = seq_len + pos;
uint64_t n_group = n_q_head / n_kv_head; size_t n_group = n_q_head / n_kv_head;
size_t alignment = 256;
if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) { if (out_desc->shape()[0] != seq_len || out_desc->shape()[1] != n_q_head || out_desc->shape()[2] != head_dim) {
return INFINI_STATUS_BAD_PARAM; return INFINI_STATUS_BAD_PARAM;
...@@ -98,12 +98,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -98,12 +98,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc)); CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_v, dst_v_desc, v_desc));
infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr; infiniopRearrangeDescriptor_t rearrange_desc_q = nullptr;
uint64_t rearranged_q_size = 0; size_t q_cont_size = 0;
infiniopTensorDescriptor_t rearranged_q_desc; infiniopTensorDescriptor_t rearranged_q_desc;
// Rearrange q into contiguous // Rearrange q into contiguous
if (!q_desc->isContiguous(0, 1)) { if (!q_desc->isContiguous(0, 1)) {
CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype())); CHECK_STATUS(infiniopCreateTensorDescriptor(&rearranged_q_desc, 3, q_desc->shape().data(), nullptr, q_desc->dtype()));
rearranged_q_size = rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype()); q_cont_size = utils::align(rearranged_q_desc->numel() * infiniSizeOf(rearranged_q_desc->dtype()), alignment);
rearrange_desc_q = new InfiniopDescriptor; rearrange_desc_q = new InfiniopDescriptor;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc)); CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_q, rearranged_q_desc, q_desc));
} }
...@@ -116,12 +116,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -116,12 +116,12 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2)); TRANSFORM_TENSOR_DESC(reshaped_q_desc, dimMerge(1, 2));
// full_k: [n_kv_head, head_dim, total_seq_len] // full_k: [n_kv_head, head_dim, total_seq_len]
infiniopTensorDescriptor_t full_k_desc; infiniopTensorDescriptor_t full_k_desc;
uint64_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim}; size_t full_k_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype())); CHECK_STATUS(infiniopCreateTensorDescriptor(&full_k_desc, 3, full_k_shape, k_cache_desc->strides().data(), k_cache_desc->dtype()));
TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1})); TRANSFORM_TENSOR_DESC(full_k_desc, dimPermute({0, 2, 1}));
// qk: [n_kv_head, n_group * seq_len, total_seq_len] // qk: [n_kv_head, n_group * seq_len, total_seq_len]
infiniopTensorDescriptor_t qk_desc; infiniopTensorDescriptor_t qk_desc;
uint64_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len}; size_t qk_shape[3] = {n_kv_head, n_group * seq_len, total_seq_len};
CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype())); CHECK_STATUS(infiniopCreateTensorDescriptor(&qk_desc, 3, qk_shape, nullptr, q_desc->dtype()));
// matmul1_desc // matmul1_desc
// qk_alpha // qk_alpha
...@@ -129,10 +129,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -129,10 +129,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopGemmDescriptor_t matmul1_desc; infiniopGemmDescriptor_t matmul1_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc)); CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, qk_desc, reshaped_q_desc, full_k_desc));
// matmul1 workspace size // matmul1 workspace size
uint64_t matmul1_workspace_size; size_t matmul1_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size)); CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &matmul1_workspace_size));
// matmul1 tensor size matmul1_workspace_size = utils::align(matmul1_workspace_size, alignment);
uint64_t matmul1_tensor_size = qk_desc->numel() * infiniSizeOf(qk_desc->dtype()); // attention score tensor size
size_t attn_score_size = utils::align(qk_desc->numel() * infiniSizeOf(qk_desc->dtype()), alignment);
// CausalSoftmax: softmax(qk) // CausalSoftmax: softmax(qk)
// qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len] // qk: [n_kv_head, n_group * seq_len, total_seq_len] -> [n_q_head, seq_len, total_seq_len]
...@@ -141,8 +142,9 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -141,8 +142,9 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
infiniopCausalSoftmaxDescriptor_t softmax_desc; infiniopCausalSoftmaxDescriptor_t softmax_desc;
CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc)); CHECK_STATUS(infiniopCreateCausalSoftmaxDescriptor(handle, &softmax_desc, qk_desc, qk_desc));
// softmax workspace size // softmax workspace size
uint64_t softmax_workspace_size; size_t softmax_workspace_size;
CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size)); CHECK_STATUS(infiniopGetCausalSoftmaxWorkspaceSize(softmax_desc, &softmax_workspace_size));
softmax_workspace_size = utils::align(softmax_workspace_size, alignment);
// Matmul2: softmax(qk) * full_v // Matmul2: softmax(qk) * full_v
// softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len] // softmax(qk): [n_q_head, seq_len, total_seq_len] -> [n_kv_head, n_group * seq_len, total_seq_len]
...@@ -150,41 +152,44 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -150,41 +152,44 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group})); TRANSFORM_TENSOR_DESC(qk_desc, dimSplit(0, {n_kv_head, n_group}));
TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2)); TRANSFORM_TENSOR_DESC(qk_desc, dimMerge(1, 2));
infiniopTensorDescriptor_t full_v_desc; infiniopTensorDescriptor_t full_v_desc;
uint64_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim}; size_t full_v_shape[3] = {n_kv_head, total_seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype())); CHECK_STATUS(infiniopCreateTensorDescriptor(&full_v_desc, 3, full_v_shape, v_cache_desc->strides().data(), v_cache_desc->dtype()));
// temp_out: [n_kv_head, n_group * seq_len, head_dim] // temp_out: [n_kv_head, n_group * seq_len, head_dim]
infiniopTensorDescriptor_t temp_out_desc; infiniopTensorDescriptor_t att_val_desc;
uint64_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim}; size_t temp_out_shape[3] = {n_kv_head, n_group * seq_len, head_dim};
CHECK_STATUS(infiniopCreateTensorDescriptor(&temp_out_desc, 3, temp_out_shape, nullptr, q_desc->dtype())); CHECK_STATUS(infiniopCreateTensorDescriptor(&att_val_desc, 3, temp_out_shape, nullptr, q_desc->dtype()));
// matmul2_desc // matmul2_desc
infiniopGemmDescriptor_t matmul2_desc; infiniopGemmDescriptor_t matmul2_desc;
CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, temp_out_desc, qk_desc, full_v_desc)); CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, att_val_desc, qk_desc, full_v_desc));
// matmul2 workspace size // matmul2 workspace size
uint64_t matmul2_workspace_size; size_t matmul2_workspace_size;
CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size)); CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &matmul2_workspace_size));
// matmul2 tensor size matmul2_workspace_size = utils::align(matmul2_workspace_size, alignment);
uint64_t matmul2_tensor_size = temp_out_desc->numel() * infiniSizeOf(temp_out_desc->dtype()); // attention value tensor size
size_t att_val_size = utils::align(att_val_desc->numel() * infiniSizeOf(att_val_desc->dtype()), alignment);
// Rearrange temp_out into out // Rearrange temp_out into out
// out: [seq_len, n_q_head, head_dim] // out: [seq_len, n_q_head, head_dim]
// temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim] // temp_out: [n_kv_head, n_group * seq_len, head_dim] -> [n_q_head, seq_len, head_dim] -> [seq_len, n_q_head, head_dim]
TRANSFORM_TENSOR_DESC(temp_out_desc, dimSplit(1, {n_group, seq_len})); TRANSFORM_TENSOR_DESC(att_val_desc, dimSplit(1, {n_group, seq_len}));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimMerge(0, 1)); TRANSFORM_TENSOR_DESC(att_val_desc, dimMerge(0, 1));
TRANSFORM_TENSOR_DESC(temp_out_desc, dimPermute({1, 0, 2})); TRANSFORM_TENSOR_DESC(att_val_desc, dimPermute({1, 0, 2}));
infiniopRearrangeDescriptor_t rearrange_desc_out; infiniopRearrangeDescriptor_t rearrange_desc_out;
CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, temp_out_desc)); CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &rearrange_desc_out, out_desc, att_val_desc));
// workspace size // workspace size
uint64_t workspace_size = rearranged_q_size + std::max(std::max(matmul1_workspace_size + matmul1_tensor_size, matmul1_tensor_size + softmax_workspace_size), matmul1_tensor_size + matmul2_workspace_size + matmul2_tensor_size); size_t op_workspace_size = utils::align(std::max(std::max(matmul1_workspace_size, matmul2_workspace_size), softmax_workspace_size), alignment);
size_t temp_tensors_size = attn_score_size + std::max(q_cont_size, att_val_size);
size_t workspace_size = temp_tensors_size + op_workspace_size;
// k_cache_offset // k_cache_offset
uint64_t k_cache_offset = 0; size_t k_cache_offset = 0;
if (pos > 0) { if (pos > 0) {
k_cache_offset = pos * k_cache_desc->getByteStrides()[1]; k_cache_offset = pos * k_cache_desc->getByteStrides()[1];
} }
// v_cache_offset // v_cache_offset
uint64_t v_cache_offset = 0; size_t v_cache_offset = 0;
if (pos > 0) { if (pos > 0) {
v_cache_offset = pos * v_cache_desc->getByteStrides()[1]; v_cache_offset = pos * v_cache_desc->getByteStrides()[1];
} }
...@@ -200,12 +205,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -200,12 +205,11 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
matmul2_desc, matmul2_desc,
softmax_desc, softmax_desc,
workspace_size, workspace_size,
rearranged_q_size, temp_tensors_size,
matmul1_workspace_size, op_workspace_size,
matmul1_tensor_size, attn_score_size,
matmul2_workspace_size, 0,
matmul2_tensor_size, attn_score_size,
softmax_workspace_size,
k_cache_offset, k_cache_offset,
v_cache_offset, v_cache_offset,
1.f / std::sqrt(float(head_dim)), 1.f / std::sqrt(float(head_dim)),
...@@ -214,14 +218,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h ...@@ -214,14 +218,14 @@ __C __export infiniStatus_t infiniopCreateAttentionDescriptor(infiniopHandle_t h
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
__C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, uint64_t *size) { __C __export infiniStatus_t infiniopGetAttentionWorkspaceSize(infiniopAttentionDescriptor_t desc, size_t *size) {
*size = ((InfiniopAttentionDescriptor *)desc)->workspace_size; *size = ((InfiniopAttentionDescriptor *)desc)->workspace_size;
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
__C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc_, __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc_,
void *workspace, void *workspace_,
uint64_t workspace_size, size_t workspace_size_,
void *out, void *out,
void const *q, void const *q,
void const *k, void const *k,
...@@ -230,11 +234,14 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc ...@@ -230,11 +234,14 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc
void *v_cache, void *v_cache,
void *stream) { void *stream) {
auto desc = (InfiniopAttentionDescriptor *)desc_; auto desc = (InfiniopAttentionDescriptor *)desc_;
void *workspace_ = workspace; if (workspace_size_ < desc->workspace_size) {
if (workspace_size < desc->workspace_size) {
return INFINI_STATUS_INSUFFICIENT_WORKSPACE; // STATUS_MEMORY_NOT_ALLOCATED return INFINI_STATUS_INSUFFICIENT_WORKSPACE; // STATUS_MEMORY_NOT_ALLOCATED
} }
void *workspace = (char *)workspace_ + desc->op_workspace_offset;
size_t workspace_size = desc->op_workspace_size;
void *att_score = (char *)workspace_ + desc->att_score_offset;
void *att_val = (char *)workspace_ + desc->att_val_offset;
void const *q_ = q;
// concat k and v to k_cache and v_cache // concat k and v to k_cache and v_cache
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k, CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_k,
(char *)k_cache + desc->k_cache_offset, k, stream)); (char *)k_cache + desc->k_cache_offset, k, stream));
...@@ -243,28 +250,26 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc ...@@ -243,28 +250,26 @@ __C __export infiniStatus_t infiniopAttention(infiniopAttentionDescriptor_t desc
(char *)v_cache + desc->v_cache_offset, v, stream)); (char *)v_cache + desc->v_cache_offset, v, stream));
// rearrange q into contiguous // rearrange q into contiguous
void const *_q = q;
if (desc->rearrange_desc_q) { if (desc->rearrange_desc_q) {
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, (char *)workspace_, q, stream)); void *q_cont = (char *)workspace_ + desc->q_cont_offset;
_q = workspace_; CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_q, q_cont, q, stream));
workspace_ = (char *)workspace_ + desc->rearranged_q_size; q_ = q_cont;
} }
// matmul1: q * full_k // matmul1: q * full_k
CHECK_STATUS(infiniopGemm(desc->matmul_desc1, CHECK_STATUS(infiniopGemm(desc->matmul_desc1,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size, workspace, workspace_size,
workspace_, _q, k_cache, desc->qk_alpha, 0.0, stream)); att_score, q_, k_cache, desc->qk_alpha, 0.0, stream));
// softmax(qk) // softmax(qk)
CHECK_STATUS(infiniopCausalSoftmax(desc->softmax_desc, CHECK_STATUS(infiniopCausalSoftmax(desc->softmax_desc,
(char *)workspace_ + desc->matmul1_tensor_size, workspace_size - desc->matmul1_tensor_size, workspace, workspace_size,
workspace_, workspace_, stream)); att_score, att_score, stream));
// matmul2: softmax(qk) * full_v // matmul2: softmax(qk) * full_v
CHECK_STATUS(infiniopGemm(desc->matmul_desc2, CHECK_STATUS(infiniopGemm(desc->matmul_desc2,
(char *)workspace_ + desc->matmul1_tensor_size + desc->matmul2_tensor_size, workspace, workspace_size,
workspace_size - desc->matmul1_tensor_size - desc->matmul2_tensor_size, att_val, att_score, v_cache, 1.0, 0.0, stream));
(char *)workspace_ + desc->matmul1_tensor_size, workspace_, v_cache, 1.0, 0.0, stream));
// rearrange out // rearrange out
CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, (char *)workspace_ + desc->matmul1_tensor_size, stream)); CHECK_STATUS(infiniopRearrange(desc->rearrange_desc_out, out, att_val, stream));
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
......
...@@ -18,7 +18,7 @@ INFINIOP_CUDA_KERNEL causalSoftmax( ...@@ -18,7 +18,7 @@ INFINIOP_CUDA_KERNEL causalSoftmax(
// [Reduce] Find max value in each row and store in shared memory // [Reduce] Find max value in each row and store in shared memory
__shared__ Tdata max_; __shared__ Tdata max_;
Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x, width); Tdata max_0 = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(x, width - height + 1 + blockIdx.x);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
max_ = max_0; max_ = max_0;
} }
......
...@@ -100,4 +100,12 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) { ...@@ -100,4 +100,12 @@ inline std::string infiniDtypeToString(infiniDtype_t dtype) {
#define CEIL_DIV(x, y) (((x) + (y)-1) / (y)) #define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
namespace utils {
/// Round `size` up to the nearest multiple of `alignment`.
///
/// @param size      Byte count to round up.
/// @param alignment Required alignment in bytes. MUST be a power of two:
///                  the bit-mask trick below only produces multiples of
///                  `alignment` in that case (callers here use 256).
/// @return Smallest multiple of `alignment` that is >= `size`.
inline constexpr size_t align(size_t size, size_t alignment) {
    // Add (alignment - 1) so any partial remainder spills into the next
    // aligned slot, then clear the low bits to land exactly on it.
    return (size + alignment - 1) & ~(alignment - 1);
}
} // namespace utils
#endif #endif
...@@ -215,7 +215,7 @@ if __name__ == "__main__": ...@@ -215,7 +215,7 @@ if __name__ == "__main__":
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float16: {"atol": 1e-4, "rtol": 1e-2}, torch.float16: {"atol": 1e-4, "rtol": 1e-2},
torch.float32: {"atol": 1e-6, "rtol": 1e-4}, torch.float32: {"atol": 1e-5, "rtol": 1e-3},
} }
DEBUG = False DEBUG = False
...@@ -268,6 +268,20 @@ if __name__ == "__main__": ...@@ -268,6 +268,20 @@ if __name__ == "__main__":
None, # k_cache_stride None, # k_cache_stride
None, # v_cache_stride None, # v_cache_stride
), ),
(
28, # n_q_head
28, # n_kv_head
15, # seq_len
128, # head_dim
0, # pos
2048, # k_cache_buf_len
2048, # v_cache_buf_len
[128, 10752, 1], # q_stride
[128, 10752, 1], # k_stride
[128, 10752, 1], # v_stride
[128, 3584, 1], # k_cache_stride
[128, 3584, 1], # v_cache_stride
),
] ]
args = get_args() args = get_args()
lib = open_lib() lib = open_lib()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment