Unverified Commit 7f24ea95 authored by Lianmin Zheng, committed by GitHub

Fuse top_k and top_p in the sampler (#1457)

parent 1acccb36
@@ -23,6 +23,7 @@ class GenerateReqInput:
# Whether to return logprobs.
return_logprob: Optional[Union[List[bool], bool]] = None
# The start location of the prompt for return_logprob.
+# By default, this value is "-1", which means it will only return logprobs for output tokens.
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return.
top_logprobs_num: Optional[Union[List[int], int]] = None
......
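The documented default above can be exercised through a generate request. Below is a minimal usage sketch, assuming a locally running server on port 30000 and the /generate JSON payload shape; the field names mirror the dataclass fields above, while the endpoint, port, and prompt are illustrative assumptions.

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {"temperature": 0.7, "max_new_tokens": 16},
    # Ask for logprobs. Leaving logprob_start_len unset keeps the default of -1,
    # so logprobs are returned only for generated (output) tokens.
    "return_logprob": True,
    # "logprob_start_len": 0,  # would also request logprobs over the prompt tokens
    "top_logprobs_num": 5,
}
response = requests.post("http://localhost:30000/generate", json=payload)
print(response.json())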
@@ -31,8 +31,11 @@ class Sampler(nn.Module):
logits = logits.next_token_logits
# Post process logits
logits = logits.contiguous()
logits.div_(sampling_info.temperatures)
-probs = logits[:] = torch.softmax(logits, dim=-1)
+probs = torch.softmax(logits, dim=-1)
+logits = None
+del logits
if torch.any(torch.isnan(probs)):
logger.warning("Detected errors during sampling! NaN in the probability.")
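The removed chained assignment also copied the softmax result back into the logits buffer; the new code skips that copy, keeps probs as the fresh softmax output, and drops the logits reference so its memory can be reclaimed before sampling. A minimal sketch of the new post-processing path, with assumed tensor shapes and a hypothetical temperature tensor standing in for sampling_info.temperatures:

import torch

logits = torch.randn(4, 32000)          # hypothetical [batch_size, vocab_size] next-token logits
temperatures = torch.full((4, 1), 0.7)  # hypothetical per-request temperatures

logits = logits.contiguous()
logits.div_(temperatures)               # in-place temperature scaling
probs = torch.softmax(logits, dim=-1)   # fresh tensor; nothing is written back into logits
logits = None                           # drop the reference so the logits buffer can be freed
del logits

if torch.any(torch.isnan(probs)):       # mirrors the NaN warning in the diff above
    print("Detected errors during sampling! NaN in the probability.")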
@@ -53,7 +56,11 @@ class Sampler(nn.Module):
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-probs, uniform_samples, sampling_info.top_ks, sampling_info.top_ps
+probs,
+uniform_samples,
+sampling_info.top_ks,
+sampling_info.top_ps,
+filter_apply_order="joint",
)
if not torch.all(success):
......
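The call above previously relied on the kernel's default filtering order, which applies the top-k cutoff first and top-p afterwards; passing filter_apply_order="joint" asks the fused kernel to apply both cutoffs against the same distribution in one pass. The following is a plain-PyTorch reference sketch of that joint behavior under assumed semantics, not the kernel's actual implementation: the function name is hypothetical, it sorts explicitly, and it samples with torch.multinomial instead of consuming the kernel's uniform_samples input.

import torch

def joint_top_k_top_p_sample(probs, top_ks, top_ps, generator=None):
    """Reference sketch: apply per-row top-k and top-p masks jointly, then sample."""
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)

    # Top-p mask: keep a token while the cumulative mass before it is below top_p
    # (the highest-probability token is always kept).
    top_p_mask = (cumsum - sorted_probs) < top_ps.unsqueeze(-1)

    # Top-k mask: keep only the first top_k sorted positions in each row.
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    top_k_mask = ranks.unsqueeze(0) < top_ks.unsqueeze(-1)

    # Joint filtering: both cutoffs are applied to the same sorted distribution.
    keep = top_p_mask & top_k_mask
    filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)

    sampled_rank = torch.multinomial(filtered, num_samples=1, generator=generator)
    return sorted_idx.gather(-1, sampled_rank).squeeze(-1)

# Hypothetical usage with a batch of 2 requests over a toy vocabulary.
probs = torch.softmax(torch.randn(2, 8), dim=-1)
top_ks = torch.tensor([3, 5])
top_ps = torch.tensor([0.9, 0.8])
next_ids = joint_top_k_top_p_sample(probs, top_ks, top_ps)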
@@ -400,8 +400,8 @@ class ModelRunner:
)
self.req_to_token_pool = ReqToTokenPool(
-max_num_reqs,
-self.model_config.context_len + 8,
+max_num_reqs + 1,
+self.model_config.context_len + 4,
)
if (
self.model_config.attention_arch == AttentionArch.MLA
......
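For context on the resized pool above: a request-to-token pool is essentially a 2-D index table with one row per request slot and one column per token position, mapping each (request, position) pair to a location in the KV cache. The sketch below is a conceptual illustration under that assumption, not SGLang's ReqToTokenPool implementation; the class name, alloc/free interface, and sizes are hypothetical, with the constructor arguments mirroring the diff above.

import torch

class ReqToTokenPoolSketch:
    """Toy request-to-token pool: row i stores the KV-cache indices of request slot i."""

    def __init__(self, size: int, max_context_len: int):
        # One row per request slot, one column per token position.
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)
        self.free_slots = list(range(size))

    def alloc(self) -> int:
        # Hand out a free row index, or fail when every slot is in use.
        if not self.free_slots:
            raise RuntimeError("request-to-token pool is full")
        return self.free_slots.pop()

    def free(self, slot: int) -> None:
        self.free_slots.append(slot)

# Hypothetical sizing that mirrors the change above: one spare request slot
# and a few columns of slack beyond the model's context length.
max_num_reqs, context_len = 256, 4096
pool = ReqToTokenPoolSketch(max_num_reqs + 1, context_len + 4)
slot = pool.alloc()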