Unverified Commit b48edff6 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Split the overlapped version of TpModelWorkerClient into a separate file (#1726)

parent 593b19f2
...@@ -639,8 +639,8 @@ class ScheduleBatch: ...@@ -639,8 +639,8 @@ class ScheduleBatch:
if isinstance(self.tree_cache, ChunkCache): if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction # ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][ token_indices = self.req_to_token_pool.req_to_token[
: seq_lens_cpu[idx] req.req_pool_idx, : seq_lens_cpu[idx]
] ]
self.token_to_kv_pool.free(token_indices) self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
...@@ -648,8 +648,8 @@ class ScheduleBatch: ...@@ -648,8 +648,8 @@ class ScheduleBatch:
else: else:
# TODO: apply more fine-grained retraction # TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices) last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][ token_indices = self.req_to_token_pool.req_to_token[
last_uncached_pos : seq_lens_cpu[idx] req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
] ]
self.token_to_kv_pool.free(token_indices) self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
......
...@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_policy import ( ...@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_policy import (
SchedulePolicy, SchedulePolicy,
) )
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
...@@ -146,9 +147,14 @@ class Scheduler: ...@@ -146,9 +147,14 @@ class Scheduler:
# Launch a tensor parallel worker # Launch a tensor parallel worker
if self.server_args.enable_overlap_schedule: if self.server_args.enable_overlap_schedule:
TpWorkerClass = TpModelWorker TpWorkerClass = TpModelWorkerClient
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
else: else:
TpWorkerClass = TpModelWorker TpWorkerClass = TpModelWorker
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.tp_worker = TpWorkerClass( self.tp_worker = TpWorkerClass(
server_args=server_args, server_args=server_args,
gpu_id=gpu_id, gpu_id=gpu_id,
...@@ -156,16 +162,6 @@ class Scheduler: ...@@ -156,16 +162,6 @@ class Scheduler:
dp_rank=dp_rank, dp_rank=dp_rank,
nccl_port=port_args.nccl_port, nccl_port=port_args.nccl_port,
) )
if self.server_args.enable_overlap_schedule:
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
else:
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.forward_batch_generation = self.tp_worker.forward_batch_generation
# Get token and memory info from the model worker # Get token and memory info from the model worker
( (
...@@ -728,7 +724,7 @@ class Scheduler: ...@@ -728,7 +724,7 @@ class Scheduler:
if self.is_generation: if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.forward_batch_generation( logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch model_worker_batch
) )
else: else:
......
...@@ -17,13 +17,8 @@ limitations under the License. ...@@ -17,13 +17,8 @@ limitations under the License.
import json import json
import logging import logging
import threading
import time
from queue import Queue
from typing import Optional from typing import Optional
import torch
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.managers.io_struct import UpdateWeightReqInput from sglang.srt.managers.io_struct import UpdateWeightReqInput
...@@ -108,9 +103,6 @@ class TpModelWorker: ...@@ -108,9 +103,6 @@ class TpModelWorker:
)[0] )[0]
set_random_seed(self.random_seed) set_random_seed(self.random_seed)
if server_args.enable_overlap_schedule:
self.init_overlap_status()
def get_worker_info(self): def get_worker_info(self):
return ( return (
self.max_total_num_tokens, self.max_total_num_tokens,
...@@ -137,81 +129,6 @@ class TpModelWorker: ...@@ -137,81 +129,6 @@ class TpModelWorker:
self.model_runner.token_to_kv_pool, self.model_runner.token_to_kv_pool,
) )
def init_overlap_status(self):
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
self.future_token_ids_ct = 0
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()
self.future_event_map = dict()
self.forward_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
self.forward_thread.start()
def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
@torch.inference_mode()
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = (
self.forward_queue.get()
)
# Resolve future tokens in the input
tic2 = time.time()
resolved_input_ids = model_worker_batch.input_ids
future_mask = resolved_input_ids < 0
resolved_input_ids[future_mask] = self.future_token_ids_map[
-resolved_input_ids[future_mask]
]
# Run forward
logits_output, next_token_ids = self.forward_batch_generation(
model_worker_batch
)
# Set future values
if model_worker_batch.return_logprob:
self.future_logits_output_dict[future_logits_output] = logits_output
# logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
# logger.info("Set event")
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist()
)
self.future_event_map[model_worker_batch.bid].set()
if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)
def resolve_future_token_ids(self, bid: int):
self.future_event_map[bid].wait()
ret = self.future_token_ids_output[bid]
del self.future_event_map[bid]
return ret
def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch): def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch) logits_output = self.model_runner.forward(forward_batch)
...@@ -224,32 +141,6 @@ class TpModelWorker: ...@@ -224,32 +141,6 @@ class TpModelWorker:
embeddings = logits_output.embeddings embeddings = logits_output.embeddings
return embeddings return embeddings
def forward_batch_generation_non_blocking(
self, model_worker_batch: ModelWorkerBatch
):
# Allocate output future objects
future_logits_output = self.future_logits_output_ct
self.future_logits_output_ct += 1
bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream):
future_next_token_ids = -torch.arange(
self.future_token_ids_ct + 1,
self.future_token_ids_ct + 1 + bs,
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids
self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret
def update_weights(self, recv_req: UpdateWeightReqInput): def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights( success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format recv_req.model_path, recv_req.load_format
......
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""A tensor parallel worker."""
import logging
import threading
import time
from queue import Queue
from typing import Optional
import torch
from sglang.srt.managers.io_struct import UpdateWeightReqInput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
class TpModelWorkerClient:
"""A tensor parallel model worker."""
def __init__(
self,
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
):
# Load the model
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device
# Create future mappings
self.future_logits_output_dict = dict()
self.future_logits_output_ct = 0
self.future_token_ids_ct = 0
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int32, device=self.device
)
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_output = dict()
# Launch a thread
self.future_event_map = dict()
self.forward_queue = Queue()
self.forward_stream = torch.cuda.Stream()
self.forward_thread = threading.Thread(
target=self.forward_thread_func,
)
self.forward_thread.start()
def get_worker_info(self):
return self.worker.get_worker_info()
def get_pad_input_ids_func(self):
return self.worker.get_pad_input_ids_func()
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()
def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
self.worker.model_runner.token_to_kv_pool,
)
def forward_thread_func(self):
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
@torch.inference_mode()
def forward_thread_func_(self):
while True:
tic1 = time.time()
model_worker_batch, future_logits_output, future_next_token_ids = (
self.forward_queue.get()
)
# Resolve future tokens in the input
tic2 = time.time()
resolved_input_ids = model_worker_batch.input_ids
future_mask = resolved_input_ids < 0
resolved_input_ids[future_mask] = self.future_token_ids_map[
-resolved_input_ids[future_mask]
]
# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch
)
# Set future values
if model_worker_batch.return_logprob:
self.future_logits_output_dict[future_logits_output] = logits_output
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
self.future_token_ids_output[model_worker_batch.bid] = (
next_token_ids.tolist()
)
self.future_event_map[model_worker_batch.bid].set()
if False:
tic3 = time.time()
self.acc_time_with_waiting += tic3 - tic1
self.acc_time_without_waiting += tic3 - tic2
if self.forward_queue.qsize() == 0:
logger.info(
f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
)
def resolve_future_token_ids(self, bid: int):
self.future_event_map[bid].wait()
ret = self.future_token_ids_output[bid]
del self.future_event_map[bid]
return ret
def resolve_future_logits_output(self, future_obj):
return self.future_logits_output_dict.pop(future_obj)
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
# Allocate output future objects
future_logits_output = self.future_logits_output_ct
self.future_logits_output_ct += 1
bs = len(model_worker_batch.seq_lens)
with torch.cuda.stream(self.forward_stream):
future_next_token_ids = -torch.arange(
self.future_token_ids_ct + 1,
self.future_token_ids_ct + 1 + bs,
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
ret = future_logits_output, future_next_token_ids
self.future_event_map[model_worker_batch.bid] = threading.Event()
self.forward_queue.put(
(model_worker_batch.copy(), future_logits_output, future_next_token_ids)
)
return ret
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
embeddings = logits_output.embeddings
return embeddings
def update_weights(self, recv_req: UpdateWeightReqInput):
success, message = self.model_runner.update_weights(
recv_req.model_path, recv_req.load_format
)
return success, message
...@@ -13,7 +13,13 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
"""Memory pool.""" """
Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a a request to its token locations.
BaseTokenToKVPool maps a token location to its KV cache data.
"""
import logging import logging
from typing import List, Tuple, Union from typing import List, Tuple, Union
...@@ -26,7 +32,7 @@ logger = logging.getLogger(__name__) ...@@ -26,7 +32,7 @@ logger = logging.getLogger(__name__)
class ReqToTokenPool: class ReqToTokenPool:
"""A memory pool that maps a request to its token locations.""" """A memory pool that maps a request to its token locations."""
def __init__(self, size: int, max_context_len: int, device: str): def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
self.size = size self.size = size
self.max_context_len = max_context_len self.max_context_len = max_context_len
self.device = device self.device = device
...@@ -34,6 +40,13 @@ class ReqToTokenPool: ...@@ -34,6 +40,13 @@ class ReqToTokenPool:
(size, max_context_len), dtype=torch.int32, device=device (size, max_context_len), dtype=torch.int32, device=device
) )
self.free_slots = list(range(size)) self.free_slots = list(range(size))
self.write_records = []
if use_records:
# records all write operations
self.write = self.write_with_records
else:
self.write = self.write_without_records
def available_size(self): def available_size(self):
return len(self.free_slots) return len(self.free_slots)
...@@ -55,16 +68,27 @@ class ReqToTokenPool: ...@@ -55,16 +68,27 @@ class ReqToTokenPool:
def clear(self): def clear(self):
self.free_slots = list(range(self.size)) self.free_slots = list(range(self.size))
self.write_records = []
def write(self, indices, values): def write_without_records(self, indices, values):
self.req_to_token[indices] = values self.req_to_token[indices] = values
def write_with_records(self, indices, values):
self.req_to_token[indices] = values
self.write_records.append((indices, values))
def get_write_records(self): def get_write_records(self):
return None ret = self.write_records
self.write_records = []
return ret
def apply_write_records(self, write_records: List[Tuple]):
for indices, values in write_records:
self.req_to_token[indices] = values
class BaseTokenToKVPool: class BaseTokenToKVPool:
"""A memory pool that maps a token to its kv cache locations""" """A memory pool that maps a token location to its kv cache data."""
def __init__( def __init__(
self, self,
......
...@@ -461,6 +461,7 @@ class ModelRunner: ...@@ -461,6 +461,7 @@ class ModelRunner:
size=max_num_reqs + 1, size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4, max_context_len=self.model_config.context_len + 4,
device=self.device, device=self.device,
use_records=False,
) )
if ( if (
self.model_config.attention_arch == AttentionArch.MLA self.model_config.attention_arch == AttentionArch.MLA
......
...@@ -170,7 +170,7 @@ class TestOpenAIVisionServer(unittest.TestCase): ...@@ -170,7 +170,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
text = response.choices[0].message.content text = response.choices[0].message.content
assert isinstance(text, str) assert isinstance(text, str)
print(text) print(text)
assert "man" in text and "taxi" in text, text assert "man" in text or "cab" in text, text
assert "logo" in text, text assert "logo" in text, text
assert response.id assert response.id
assert response.created assert response.created
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment