Unverified Commit d5a89465 authored by Li Zhang's avatar Li Zhang Committed by GitHub
Browse files

set smem size for repetition penalty kernel (#818)

parent a54b16a2
...@@ -446,10 +446,16 @@ void invokeBatchApplyRepetitionPenalty(T* logits, ...@@ -446,10 +446,16 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
dim3 grid(local_batch_size); dim3 grid(local_batch_size);
size_t smem_size = step * (sizeof(float) + sizeof(int)); size_t smem_size = step * (sizeof(float) + sizeof(int));
if (penalty_type == RepetitionPenaltyType::Additive) { if (penalty_type == RepetitionPenaltyType::Additive) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>( batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
} }
else if (penalty_type == RepetitionPenaltyType::Multiplicative) { else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size));
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>( batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>(
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment