Unverified commit 681e7af3, authored by Chang Su, committed by GitHub

[OAI] Support non-normalized logprobs in OpenAI server (#5961)

parent 681fdc26
@@ -86,11 +86,9 @@ class Sampler(nn.Module):
                 # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                 # https://github.com/flashinfer-ai/flashinfer/issues/708
                 # so we use the torch implementation.
-
-                # clamp to avoid -inf
-                logprobs = torch.log(
-                    top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                ).clamp(min=torch.finfo(probs.dtype).min)
+                # NOTE: OpenAI's logprobs is independent of top-p, we use the
+                # same rule.
+                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
             if sampling_info.need_min_p_sampling:
@@ -121,10 +119,7 @@ class Sampler(nn.Module):
             )
             if return_logprob:
-                # clamp to avoid -inf
-                logprobs = torch.log(
-                    top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                ).clamp(min=torch.finfo(probs.dtype).min)
+                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
         else:
             raise ValueError(
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
             )
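To make the behavioral change concrete: before this commit, returned logprobs were computed over the top-p renormalized distribution, so they depended on the request's `top_p`; after it, they are plain `torch.log(probs)`, matching OpenAI's behavior of reporting logprobs independently of the sampling cutoff. The sketch below is a minimal, self-contained illustration under that assumption; the `top_p_renormalize` helper and the toy logits are made up for demonstration and are not the actual `top_p_normalize_probs_torch` from this repository.

```python
import torch


def top_p_renormalize(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    """Toy top-p renormalization: keep the smallest prefix of sorted probs
    whose cumulative mass reaches top_p, zero out the rest, renormalize."""
    sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    # Drop tokens that begin after the top_p mass has already been covered.
    mask = (cumsum - sorted_probs) >= top_p
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    kept = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    return kept / kept.sum(dim=-1, keepdim=True)


probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), dim=-1)

# Old behavior (sketch): logprobs of the top-p renormalized distribution.
renormalized_logprobs = torch.log(top_p_renormalize(probs, top_p=0.9)).clamp(
    min=torch.finfo(probs.dtype).min
)

# New behavior: logprobs of the raw distribution, independent of top-p.
raw_logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

print(renormalized_logprobs)  # kept tokens shifted up, dropped tokens at the clamp floor
print(raw_logprobs)           # plain log-softmax values for every token
```

In the renormalized version, kept tokens are shifted toward 0 and truncated tokens sit at the dtype-minimum clamp floor; the raw version returns ordinary log-softmax values for all tokens, which is what the OpenAI-compatible server now reports.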