Unverified commit e646c590, authored by Lianmin Zheng, committed by GitHub

Fix logprob in the overlapped mode (#1795)

parent c555ce2c
@@ -60,7 +60,7 @@ pip install "sglang[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 ### Method 2: From source
 ```
@@ -75,7 +75,7 @@ pip install -e "python[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 ### Method 3: Using docker
 The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
...
@@ -11,7 +11,7 @@ pip install "sglang[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 ### Method 2: From source
 ```
@@ -26,7 +26,7 @@ pip install -e "python[all]"
 pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ```
-**Important: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.**
+Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions.
 ### Method 3: Using docker
 The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
...
@@ -33,17 +33,17 @@ class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor
+    next_token_logprobs: torch.Tensor = None
     # The normalized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor
+    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token, vocab_size]
-    input_token_logprobs: torch.Tensor
+    input_token_logprobs: torch.Tensor = None
     # The logprob and id of the top-k tokens in input positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
-    input_top_logprobs: List
+    input_top_logprobs: List = None
     # The logprob and id of the top-k tokens in output positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
-    output_top_logprobs: List
+    output_top_logprobs: List = None
 
 @dataclasses.dataclass
...
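With these defaults, a caller can construct a `LogitsProcessorOutput` with only the logits populated and fill in logprob fields later, which is what the overlapped CUDA-graph path below does. A minimal standalone sketch of that pattern (the class name and shapes here are illustrative stand-ins, not the sglang definitions):

```
# Sketch only: a dataclass whose logprob fields default to None, mirroring the change above.
import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class OutputSketch:  # hypothetical stand-in for LogitsProcessorOutput
    next_token_logits: torch.Tensor                            # [#seq, vocab_size]
    next_token_logprobs: Optional[torch.Tensor] = None         # [#seq, vocab_size]
    normalized_prompt_logprobs: Optional[torch.Tensor] = None  # [#seq]
    input_token_logprobs: Optional[torch.Tensor] = None        # [#token, vocab_size]
    input_top_logprobs: Optional[List] = None
    output_top_logprobs: Optional[List] = None


# The graph-replay path can now build an output from logits alone ...
out = OutputSketch(next_token_logits=torch.randn(2, 8))
# ... and attach logprobs only when the request actually asked for them.
out.next_token_logprobs = torch.log_softmax(out.next_token_logits, dim=-1)
```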
@@ -833,6 +833,7 @@ class Scheduler:
         if self.enable_overlap:
             logits_output, next_token_ids = self.tp_worker.resulve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
         else:
             # Move next_token_ids and logprobs to cpu
             if batch.return_logprob:
...
@@ -103,6 +103,8 @@ class TpModelWorkerClient:
         while True:
             self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
             self.has_inflight_batch = True
             self.launch_event = threading.Event()
@@ -122,19 +124,48 @@ class TpModelWorkerClient:
             ] = next_token_ids
 
             # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()
 
             self.launch_event.set()
-            self.copy_queue.put((copy_event, next_token_ids))
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))
 
     def copy_thread_func(self):
         while True:
-            copy_event, next_token_ids = self.copy_queue.get()
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
             while not copy_event.query():
                 time.sleep(1e-5)
-            self.output_queue.put((None, next_token_ids.tolist()))
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )
+            self.output_queue.put((logits_output, next_token_ids.tolist()))
 
     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()
@@ -172,3 +203,7 @@ class TpModelWorkerClient:
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
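The worker client above relies on a common CUDA handoff pattern: the forward thread issues non-blocking device-to-host copies, records a `torch.cuda.Event`, and passes everything to a copy thread that polls the event before touching the CPU tensors. Below is a standalone sketch of that handoff, not sglang code; it assumes a CUDA device is available, and the `producer`/`consumer` names are hypothetical.

```
# Sketch of the event-polling copy handoff (requires a CUDA device).
import queue
import threading
import time

import torch

copy_queue: queue.Queue = queue.Queue()


def producer() -> None:
    values = torch.randn(8, device="cuda")
    # Start a device-to-host copy; with pageable memory this may still synchronize,
    # but the structure mirrors the non_blocking copies in the diff above.
    host_values = values.to("cpu", non_blocking=True)
    copy_event = torch.cuda.Event(blocking=True)
    copy_event.record()  # completes once the copy and prior kernels finish
    copy_queue.put((copy_event, host_values))
    copy_queue.put((None, None))  # sentinel, like the (None, None, None) put in __delete__


def consumer() -> None:
    while True:
        copy_event, host_values = copy_queue.get()
        if copy_event is None:
            break
        while not copy_event.query():  # poll instead of blocking the whole thread
            time.sleep(1e-5)
        print(host_values.tolist())  # now safe to read on the CPU


thread = threading.Thread(target=consumer)
thread.start()
producer()
thread.join()
```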
@@ -263,7 +263,8 @@ class CudaGraphRunner:
                 positions=clamp_position(seq_lens),
                 mrope_positions=mrope_positions,
             )
-            return forward(input_ids, forward_batch.positions, forward_batch)
+            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
+            return logits_output.next_token_logits
 
         for _ in range(2):
             torch.cuda.synchronize()
@@ -318,23 +319,16 @@ class CudaGraphRunner:
         # Replay
         self.graphs[bs].replay()
-        logits_output = self.output_buffers[bs]
-        # Unpad
-        if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
-            )
+        next_token_logits = self.output_buffers[bs][:raw_bs]
 
         # Extract logprobs
         if forward_batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            next_token_logprobs = torch.nn.functional.log_softmax(
+                next_token_logits, dim=-1
+            )
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+                next_token_logprobs=next_token_logprobs,
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
@@ -343,7 +337,11 @@ class CudaGraphRunner:
                     top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                    next_token_logprobs, logits_metadata
                 )[1]
+        else:
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+            )
 
         return logits_output
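The replay path combines two steps that also appear in the worker client: `log_softmax` turns the raw next-token logits into a full logprob distribution, and advanced indexing then picks out the logprob of the token that was actually sampled. A small self-contained example of that arithmetic (toy shapes, not sglang code):

```
# Sketch: from next-token logits to the logprob of each sampled token.
import torch

batch_size, vocab_size = 3, 11
next_token_logits = torch.randn(batch_size, vocab_size)
next_token_ids = torch.randint(vocab_size, (batch_size,))

# Full distribution, shape [#seq, vocab_size] -- what return_logprob requests.
next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

# Logprob of the sampled token per sequence, shape [#seq] -- the same indexing
# the forward thread applies before copying results to the CPU.
chosen_logprobs = next_token_logprobs[torch.arange(batch_size), next_token_ids]
print(chosen_logprobs)
```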
@@ -19,7 +19,6 @@ suites = {
         "test_openai_server.py",
         "test_overlap_schedule.py",
         "test_pytorch_sampling_backend.py",
-        "test_radix_attention.py",
         "test_retract_decode.py",
         "test_server_args.py",
         "test_skip_tokenizer_init.py",
...