Unverified Commit d5a89465 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

set smem size for repetition penalty kernel (#818)

parent a54b16a2
......@@ -446,10 +446,16 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
dim3 grid(local_batch_size);
size_t smem_size = step * (sizeof(float) + sizeof(int));
if (penalty_type == RepetitionPenaltyType::Additive) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment