Unverified commit bbc07c41, authored by Lianmin Zheng and committed by GitHub

Move sampling logits to float32 (#773)

parent a036d419
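
Why float32: with an fp16/bf16 LM head, the final logits come out in half precision, and probabilities derived from them can underflow across a large vocabulary, which destabilizes top-k/top-p sampling. A minimal sketch of the effect (illustrative only, not part of the commit):

```python
# Illustrative sketch (not part of the commit): half-precision probability
# underflow over a large vocabulary. Casting logits to float32 before
# softmax/sampling, as this commit does, avoids it.
import torch

logits = 10 * torch.randn(1, 32000)            # fake final logits, 32k vocab
probs = torch.softmax(logits.float(), dim=-1)  # float32 reference

# Fraction of tokens whose probability rounds to exactly 0.0 in fp16
# (often a large part of the tail) vs. none in float32:
print((probs.half() == 0).float().mean().item())
print((probs == 0).float().mean().item())
```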
```diff
@@ -136,7 +136,7 @@ class LogitsProcessor(nn.Module):
         last_logits = torch.matmul(last_hidden, weight.T)
         if self.tp_size > 1:
             last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()
         if hasattr(self.config, "final_logit_softcapping"):
             last_logits /= self.config.final_logit_softcapping
```
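
For reference, the `final_logit_softcapping` branch in this hunk shows only the first step (dividing by the cap); the full transform, as used by e.g. Gemma 2, is tanh-based, and casting to float32 first means the tanh also runs at full precision. A hedged sketch of the whole operation:

```python
# Hedged sketch of tanh logit soft-capping (e.g. Gemma 2's
# final_logit_softcapping); only the initial division is visible above.
import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Smoothly bound logits to (-cap, cap): cap * tanh(logits / cap).
    return cap * torch.tanh(logits / cap)
```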
```diff
@@ -161,9 +161,9 @@ class LogitsProcessor(nn.Module):
         all_logits = torch.matmul(hidden_states, weight.T)
         if self.tp_size > 1:
             all_logits = tensor_model_parallel_all_gather(all_logits)
-        all_logits = all_logits[:, : self.config.vocab_size]
-        all_logprobs = all_logits.float()
+        all_logits = all_logits[:, : self.config.vocab_size].float()
+        all_logprobs = all_logits
         del all_logits
         all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
```
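
Since the slice is now cast to float32 at the source, the old `all_logits.float()` copy became redundant: `all_logprobs` simply aliases the same buffer and `log_softmax` is written back in place. A small sketch of the buffer-reuse pattern:

```python
# Sketch: reuse one float32 [num_tokens, vocab] buffer for logprobs.
# log_softmax still allocates a temporary, but it is freed right after
# the copy-back, so only one long-lived vocab-sized tensor remains.
import torch

x = torch.randn(4, 32000)                          # float32 logits
x[:] = torch.nn.functional.log_softmax(x, dim=-1)  # write back in place
```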
```diff
@@ -687,13 +687,21 @@ class Batch:
         # TODO(lmzheng): apply penalty
         probs = torch.softmax(logits, dim=-1)
-        max_top_k_round, batch_size = 32, probs.shape[0]
-        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-            probs, uniform_samples, self.top_ks, self.top_ps
-        )
-        if torch.any(~success):
+        if True:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps
+            )
+        if not torch.all(success):
             warnings.warn("Sampling failed, fallback to top_k=1 strategy")
             probs = probs.masked_fill(torch.isnan(probs), 0.0)
             argmax_ids = torch.argmax(probs, dim=-1)
```
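
The fast path above is flashinfer's rejection-based kernel, which (going by the `max_top_k_round` name and the noise shape) appears to consume one row of uniform noise per retry round, hence the `(32, batch_size)` tensor, and returns a per-row success flag; the `else:` branch keeps a dependency-free PyTorch fallback, defined in the next hunk. A hedged sketch of the core primitive the uniform noise drives, inverse-CDF sampling of a categorical distribution:

```python
# Hedged sketch of one noise-driven sampling round: inverse-CDF sampling.
# Rows that never satisfy the top-k/top-p constraints within the retry
# budget come back from the kernel with success=False.
import torch

probs = torch.softmax(torch.randn(4, 100), dim=-1)
u = torch.rand(4, 1)                 # one uniform sample per row
cdf = torch.cumsum(probs, dim=-1)    # per-row cumulative distribution
ids = torch.searchsorted(cdf, u).clamp_(max=probs.shape[-1] - 1).squeeze(1)
```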
```diff
@@ -933,3 +941,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
         max_extend_len = int(torch.max(extend_seq_lens))

     return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+def top_k_top_p_sampling_from_probs_torch(
+    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+):
+    """A top-k and top-p sampling implementation with native PyTorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError:
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
```
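
A quick usage sketch for the fallback above (assuming the new function is in scope). `torch.multinomial` accepts unnormalized non-negative weights, so dividing each filtered row by its max rather than its sum is sufficient; a row that filters down to all zeros turns into NaNs after the division, `multinomial` raises `RuntimeError`, and the function reports `success=False` so the caller can drop to the top_k=1 (argmax) strategy:

```python
# Hedged usage sketch for top_k_top_p_sampling_from_probs_torch.
import torch

probs = torch.softmax(torch.randn(3, 1000), dim=-1)
ids, ok = top_k_top_p_sampling_from_probs_torch(
    probs,
    top_ks=torch.tensor([1, 40, 1000]),    # per-row top-k
    top_ps=torch.tensor([1.0, 0.9, 1.0]),  # per-row top-p
)
assert ids.shape == (3,) and ok.all()
```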