Unverified Commit 71ade772 authored by Lyu Han's avatar Lyu Han Committed by GitHub
Browse files

[Fix] Set max dynamic smem size for decoder MHA to support context length > 8k (#377)

* Fix crash when context window size is large by setting max dynamic smem size

* fix linting
parent 57cf99b9
......@@ -28,16 +28,18 @@
// Launch the masked multi-head-attention (MMHA) decoder kernel.
//
// Names expected in the invoking scope:
//   - params: kernel parameter struct (read here for num_heads / batch_size,
//             and passed through to the kernel and to smem_size_in_bytes)
//   - stream: CUDA stream the kernel is launched on
//
// Grid layout: one block per (head, batch) pair, THDS_PER_BLOCK threads each.
// The dynamic shared-memory requirement grows with the context length and can
// exceed the default 48 KB per-block limit; we therefore raise the kernel's
// cudaFuncAttributeMaxDynamicSharedMemorySize to smem_sz before launching
// (opting in to >48 KB requires Volta/SM70 or newer). Taking the kernel's
// address once into `func` lets the same instantiation be used for both the
// attribute call and the launch.
// NOTE(review): the cudaFuncSetAttribute return code is not checked here,
// matching the surrounding code's style — a failure would surface at the
// subsequent launch.
#define MMHA_LAUNCH_KERNEL(                                                                                            \
    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream)                      \
    auto func = &mmha::masked_multihead_attention_kernel<T,                                                            \
                                                         Dh,                                                           \
                                                         Dh_MAX,                                                       \
                                                         THDS_PER_KEY,                                                 \
                                                         THDS_PER_VALUE,                                               \
                                                         THDS_PER_BLOCK,                                               \
                                                         HAS_BEAMS,                                                    \
                                                         QUANT_POLICY>;                                                \
    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);                              \
    dim3   grid(params.num_heads, params.batch_size);                                                                  \
    cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);                                  \
    func<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment