"tests/multimodal/test_mapper.py" did not exist on "6984c02a2735d4d08426d2c426c34b6d73bee89e"
Unverified Commit 299ebb62 authored by Chanh Nguyen, committed by GitHub
Browse files

[Core] Speed up decode by remove synchronizing operation in sampler (#16436)


Signed-off-by: Chanh Nguyen <cnguyen@linkedin.com>
Co-authored-by: Chanh Nguyen <cnguyen@linkedin.com>
parent f728ab8e
......@@ -47,10 +47,15 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor, vocab_size, num_seqs)
repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
1, vocab_size)
logits[logits > 0] /= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits > 0]
logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
repetition_penalties, 1.0)[logits <= 0]
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.
penalties = torch.where(prompt_mask | output_mask, repetition_penalties,
1.0)
# If logits are positive, divide by penalty, otherwise multiply by penalty.
scaling = torch.where(logits > 0, 1.0 / penalties, penalties)
logits *= scaling
# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment