Unverified Commit 734daedd authored by Byron Hsu, committed by GitHub

[fix] Clamp logprob with dtype min to prevent `-inf` (#3224)

parent 3ee62235
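
Why the clamp: top_p_normalize_probs_torch renormalizes the distribution after top-p filtering, so tokens cut from the nucleus carry probability exactly 0, and torch.log maps those entries to -inf. Clamping with the dtype's minimum finite value bounds them instead. A minimal standalone sketch of the behavior (illustrative values, not the sglang code):

import torch

# Hypothetical post-top-p distribution: tokens filtered out by top-p
# are left with probability exactly 0.
probs = torch.tensor([0.7, 0.3, 0.0, 0.0])

print(torch.log(probs))
# tensor([-0.3567, -1.2040,    -inf,    -inf])

# The fix applied in this commit's diff: clamp to the dtype's minimum
# finite value so no -inf survives.
logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
print(logprobs)
# tensor([-3.5667e-01, -1.2040e+00, -3.4028e+38, -3.4028e+38])

For float32, torch.finfo(probs.dtype).min is about -3.4e38, so clamped entries remain extreme enough to read as "effectively impossible" while staying finite for downstream arithmetic and serialization.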
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
                 # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                 # https://github.com/flashinfer-ai/flashinfer/issues/708
                 # so we use the torch implementation.
+
+                # clamp to avoid -inf
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                )
+                ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
@@ -109,9 +111,10 @@ class Sampler(nn.Module):
                     sampling_info.need_min_p_sampling,
                 )
                 if return_logprob:
+                    # clamp to avoid -inf
                     logprobs = torch.log(
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                    )
+                    ).clamp(min=torch.finfo(probs.dtype).min)
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
...
@@ -36,7 +36,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
     def run_decode(
         self,
         return_logprob=True,
-        top_logprobs_num=3,
+        top_logprobs_num=5,
         return_text=True,
         n=1,
         **sampling_params,
@@ -58,8 +58,7 @@ class TestBatchPenalizerE2E(unittest.TestCase):
                 "logprob_start_len": 0,
             },
         )
-        print(json.dumps(response.json()))
-        print("=" * 100)
+        assert response.status_code == 200, "Request failed: " + response.text
 
     def test_default_values(self):
         self.run_decode()
@@ -112,4 +111,4 @@ class TestBatchPenalizerE2E(unittest.TestCase):
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=3)