Commit 336949c2 authored by fengzch

fix: compile attention.cu completely

Add the `template` disambiguator when naming the dependent member template
`attention_fp16_kernel`, and cast the kernel function pointer to
`const void *` for `cudaFuncSetAttribute`.

parent 8c9a37b1
@@ -58,7 +58,7 @@ void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
   using packed_k_t = typename Attention::packed_k_t;
   using packed_v_t = typename Attention::packed_v_t;
-  auto func = invoke_kernel<typename Attention::attention_fp16_kernel<Epilogue>,
+  auto func = invoke_kernel<typename Attention::template attention_fp16_kernel<Epilogue>,
                             const packed_q_t *,
                             const packed_k_t *,
                             const packed_v_t *,
@@ -71,7 +71,7 @@ void attention_fp16(Tensor q, // packed [Batch, Head, TokensQ, HEAD_DIM]
   shmem = std::max(shmem, Attention::template attention_fp16_kernel<Epilogue>::SHMEM_SIZE);
   if (shmem >= 24 * 1024) {
-    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+    checkCUDA(cudaFuncSetAttribute(reinterpret_cast<const void*>(func), cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
   }
   func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, shmem, getCurrentCUDAStream()>>>(q.data_ptr<packed_q_t>(),
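
Note: the two changes above address standard C++/CUDA pitfalls. First, `attention_fp16_kernel` is a member template of the dependent type `Attention`, so naming it as a type requires both the `typename` and `template` disambiguators. Second, the `cudaFuncSetAttribute` C API takes its entry point as `const void *`, and a function pointer does not convert implicitly to an object pointer, hence the explicit `reinterpret_cast`. The sketch below reproduces both fixes in isolation; all names (`invoke_kernel`, `Attention`, `Epilogue`, `launch`) are hypothetical stand-ins mirroring the diff, not the project's actual code.

```cpp
// Minimal sketch (hypothetical names) reproducing both compile fixes.
#include <cstdio>
#include <cuda_runtime.h>

// __global__ functions cannot be class members, so a free wrapper
// kernel instantiates the functor type passed as a template argument.
template <typename Kernel, typename... Args>
__global__ void invoke_kernel(Args... args) {
    Kernel()(args...);
}

template <int HEAD_DIM>
struct Attention {
    // Member template functor standing in for attention_fp16_kernel.
    template <typename Epilogue>
    struct attention_fp16_kernel {
        __device__ void operator()(float *out) const {
            if (threadIdx.x == 0) *out = static_cast<float>(HEAD_DIM);
        }
    };
};

struct Epilogue {};

template <int HEAD_DIM>
void launch(float *out, int shmem) {
    // Fix 1: attention_fp16_kernel is a member template of a dependent
    // type, so both `typename` and `template` disambiguators are needed.
    auto func = invoke_kernel<
        typename Attention<HEAD_DIM>::template attention_fp16_kernel<Epilogue>,
        float *>;
    // Fix 2: the cudaFuncSetAttribute C API takes `const void *`; a typed
    // kernel pointer does not convert implicitly, hence the cast.
    cudaFuncSetAttribute(reinterpret_cast<const void *>(func),
                         cudaFuncAttributeMaxDynamicSharedMemorySize, shmem);
    func<<<1, 32, shmem>>>(out);
}

int main() {
    float *out = nullptr;
    cudaMalloc(&out, sizeof(float));
    launch<64>(out, 32 * 1024);
    cudaDeviceSynchronize();
    float host = 0.f;
    cudaMemcpy(&host, out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("HEAD_DIM written by kernel: %.0f\n", host);
    cudaFree(out);
    return 0;
}
```

Built with `nvcc sketch.cu -o sketch`, this prints the HEAD_DIM value written by the kernel, confirming that the disambiguated kernel instantiation launches correctly after the attribute is set.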