Commit 2e528580 authored by zhouxiang's avatar zhouxiang
Browse files

解决原框架unfused_attention不支持大于4k输入的问题

parent 441af933
...@@ -453,17 +453,23 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam<T, T_IN>& param, cudaStream_t stream ...@@ -453,17 +453,23 @@ void invokeMaskedSoftmax(MaskedSoftmaxParam<T, T_IN>& param, cudaStream_t stream
bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0; bool is_half2 = sizeof(T) == 2 && sizeof(T_IN) == 2 && param.k_length % 2 == 0;
dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32); dim3 block((param.k_length / (is_half2 ? 2 : 1) + 31) / 32 * 32);
if (block.x > 2048 && block.x <= 4096) { if (block.x > 8192 && block.x <= 16384) {
LAUNCH_MAKSED_SOFTMAX(16)
}
if (block.x > 4096 && block.x <= 8192) {
LAUNCH_MAKSED_SOFTMAX(8)
}
else if (block.x > 2048 && block.x <= 4096) {
LAUNCH_MAKSED_SOFTMAX(4) LAUNCH_MAKSED_SOFTMAX(4)
} }
else if (block.x > 1024) { else if (block.x > 1024 && block.x <= 2048) {
LAUNCH_MAKSED_SOFTMAX(2) LAUNCH_MAKSED_SOFTMAX(2)
} }
else if (block.x > 0) { else if (block.x > 0 && block.x <= 1024) {
LAUNCH_MAKSED_SOFTMAX(1) LAUNCH_MAKSED_SOFTMAX(1)
} }
else { else {
FT_CHECK(param.k_length <= 4096); FT_CHECK(param.k_length <= 16384);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment