Unverified Commit 3de2f30a authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Flashinfer sample kernel (#617)

parent 4efcc59d
...@@ -156,14 +156,14 @@ def extend(reqs, model_runner): ...@@ -156,14 +156,14 @@ def extend(reqs, model_runner):
) )
batch.prepare_for_extend(model_runner.model_config.vocab_size, None) batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
output = model_runner.forward(batch, ForwardMode.EXTEND) output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids, _ = batch.sample(output.next_token_logits) next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits, batch return next_token_ids, output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner): def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids.cpu().numpy()) batch.prepare_for_decode(input_token_ids.cpu().numpy())
output = model_runner.forward(batch, ForwardMode.DECODE) output = model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, _ = batch.sample(output.next_token_logits) next_token_ids = batch.sample(output.next_token_logits)
return next_token_ids, output.next_token_logits return next_token_ids, output.next_token_logits
......
...@@ -7,6 +7,7 @@ from typing import List, Union ...@@ -7,6 +7,7 @@ from typing import List, Union
import numpy as np import numpy as np
import torch import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs
from sglang.srt.constrained import RegexGuide from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap from sglang.srt.constrained.jump_forward import JumpForwardMap
...@@ -398,10 +399,10 @@ class Batch: ...@@ -398,10 +399,10 @@ class Batch:
).view(-1, 1) ).view(-1, 1)
self.top_ps = torch.tensor( self.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
).view(-1, 1) )
self.top_ks = torch.tensor( self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
).view(-1, 1) )
self.frequency_penalties = torch.tensor( self.frequency_penalties = torch.tensor(
[r.sampling_params.frequency_penalty for r in reqs], [r.sampling_params.frequency_penalty for r in reqs],
dtype=torch.float, dtype=torch.float,
...@@ -659,20 +660,17 @@ class Batch: ...@@ -659,20 +660,17 @@ class Batch:
# TODO(lmzheng): apply penalty # TODO(lmzheng): apply penalty
probs = torch.softmax(logits, dim=-1) probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = _top_p_top_k(probs, self.top_ps, self.top_ks)
try: try:
sampled_index = torch.multinomial(probs_sort, num_samples=1) max_top_k_round, batch_size = 32, probs.shape[0]
except RuntimeError as e: uniform_samples = torch.rand(
warnings.warn(f"Ignore errors in sampling: {e}") (max_top_k_round, batch_size), device=probs.device
sampled_index = torch.ones(
probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
) )
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view( batch_next_token_ids, _ = top_k_top_p_sampling_from_probs(
-1 probs, uniform_samples, self.top_ks, self.top_ps
) )
batch_next_token_probs = torch.gather( except RuntimeError as e:
probs_sort, dim=1, index=sampled_index warnings.warn(f"Ignore errors in sampling: {e}")
).view(-1) batch_next_token_ids = torch.argmax(probs, dim=-1)
if has_regex: if has_regex:
batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy() batch_next_token_ids_cpu = batch_next_token_ids.cpu().numpy()
...@@ -682,18 +680,7 @@ class Batch: ...@@ -682,18 +680,7 @@ class Batch:
req.regex_fsm_state, batch_next_token_ids_cpu[i] req.regex_fsm_state, batch_next_token_ids_cpu[i]
) )
return batch_next_token_ids, batch_next_token_probs return batch_next_token_ids
def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps] = 0.0
probs_sort[
torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks
] = 0.0
probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
return probs_sort, probs_idx
@dataclass @dataclass
......
...@@ -451,7 +451,7 @@ class ModelTpServer: ...@@ -451,7 +451,7 @@ class ModelTpServer:
# Forward and sample the next tokens # Forward and sample the next tokens
if batch.extend_num_tokens != 0: if batch.extend_num_tokens != 0:
output = self.model_runner.forward(batch, ForwardMode.EXTEND) output = self.model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids, _ = batch.sample(output.next_token_logits) next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu # Move logprobs to cpu
if output.next_token_logprobs is not None: if output.next_token_logprobs is not None:
...@@ -574,7 +574,7 @@ class ModelTpServer: ...@@ -574,7 +574,7 @@ class ModelTpServer:
# Forward and sample the next tokens # Forward and sample the next tokens
output = self.model_runner.forward(batch, ForwardMode.DECODE) output = self.model_runner.forward(batch, ForwardMode.DECODE)
next_token_ids, _ = batch.sample(output.next_token_logits) next_token_ids = batch.sample(output.next_token_logits)
# Move logprobs to cpu # Move logprobs to cpu
if output.next_token_logprobs is not None: if output.next_token_logprobs is not None:
......
...@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg ...@@ -154,7 +154,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
if not server_args.disable_flashinfer: if not server_args.disable_flashinfer:
assert_pkg_version( assert_pkg_version(
"flashinfer", "flashinfer",
"0.0.8", "0.1.0",
"Please uninstall the old version and " "Please uninstall the old version and "
"reinstall the latest version by following the instructions " "reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.", "at https://docs.flashinfer.ai/installation.html.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment