Unverified Commit 722530fa authored by Lianmin Zheng, committed by GitHub

Enable overlap scheduler by default for the triton attention backend (#2105)

parent 56a347f7
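Background note: the overlap scheduler hides CPU-side batch scheduling behind GPU execution by preparing the next batch while the current one is still running on the device. Below is a minimal, self-contained sketch of that pipelining idea only; names are hypothetical and this is not sglang's TpModelWorkerClient.

```python
# Minimal sketch of the overlap-scheduling idea: a worker thread consumes
# batches from a queue while the main thread keeps preparing and submitting
# the next batch, so CPU-side work overlaps with (simulated) GPU execution.
import queue
import threading
import time


def fake_gpu_forward(batch):
    time.sleep(0.01)  # stand-in for a GPU forward pass
    return [tok + 1 for tok in batch]


def worker_loop(in_q, out_q):
    while True:
        batch = in_q.get()
        if batch is None:  # sentinel: shut down
            break
        out_q.put(fake_gpu_forward(batch))


in_q, out_q = queue.Queue(), queue.Queue()
worker = threading.Thread(target=worker_loop, args=(in_q, out_q), daemon=True)
worker.start()

# Submit the next batch before waiting for the previous result, so batch
# preparation on the CPU overlaps with the in-flight "GPU" work.
in_q.put([1, 2, 3])
in_q.put([4, 5, 6])
print(out_q.get())  # [2, 3, 4]
print(out_q.get())  # [5, 6, 7]

in_q.put(None)
worker.join()
```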
@@ -53,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
             start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
             start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
 
-            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            total_num_tokens = forward_batch.seq_lens_sum
             attn_logits = torch.empty(
                 (self.num_head, total_num_tokens),
                 dtype=self.reduce_dtype,
...
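This one-line change matters for overlapping: `torch.sum(...).item()` copies the result back to the host and blocks the CPU until the GPU catches up, while `forward_batch.seq_lens_sum` is already a host-side integer. A small illustration of the difference; the precomputed value here is hypothetical.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([5, 7, 3], device=device)

# .item() forces a device-to-host read, which synchronizes the CPU with the
# GPU stream on a CUDA device; this is exactly the stall the overlap
# scheduler tries to avoid on the hot path.
total_blocking = torch.sum(seq_lens).item()

# If the sum is already tracked on the host (as forward_batch.seq_lens_sum is
# in this diff), no read-back is needed at all.
seq_lens_sum = 5 + 7 + 3  # hypothetical precomputed host-side value
assert total_blocking == seq_lens_sum
```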
@@ -170,18 +170,9 @@ class Scheduler:
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
 
-        if (
-            server_args.attention_backend == "triton"
-            or server_args.enable_double_sparsity
-            or (
-                self.model_config.attention_arch == AttentionArch.MLA
-                and not self.server_args.disable_mla
-            )
-        ):
-            self.enable_overlap = False
-            logger.info(
-                "Overlap scheduler is disabled if using triton attention backend."
-            )
+        if self.enable_overlap:
+            self.disable_jump_forward = True
 
         # Launch a tensor parallel worker
         if self.enable_overlap:
...
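After this hunk the scheduler no longer turns overlapping off for the triton backend (or double sparsity / MLA); instead, jump-forward decoding is turned off whenever overlapping is on. A simplified stand-in for the resulting control flow, not the actual Scheduler class:

```python
# Simplified stand-in for the post-change scheduler gating: overlap stays on
# for generation models, and jump-forward decoding is disabled when it is on.
class MiniScheduler:
    def __init__(self, is_generation: bool, enable_overlap: bool):
        self.enable_overlap = enable_overlap
        self.disable_jump_forward = False

        if not is_generation:
            # Embedding models still cannot use the overlap scheduler.
            self.enable_overlap = False

        if self.enable_overlap:
            self.disable_jump_forward = True


s = MiniScheduler(is_generation=True, enable_overlap=True)
assert s.enable_overlap and s.disable_jump_forward
```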
@@ -94,10 +94,21 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
+        batch_pt = 0
+        batch_lists = [None] * 2
+
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
 
+            # Keep a reference of model_worker_batch by storing it into a list.
+            # Otherwise, the tensor members of model_worker_batch will be released
+            # by pytorch and cause CUDA illegal memory access errors.
+            batch_lists[batch_pt % 2] = model_worker_batch
+            batch_pt += 1
 
             # Create event
             self.launch_done = threading.Event()
             copy_done = torch.cuda.Event()
...
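The comment in this hunk explains the bug being fixed: if the only reference to `model_worker_batch` is dropped while asynchronously launched CUDA work still reads its tensors, the allocator can reuse that memory and trigger illegal memory access. The two-slot list pins the last two batches. A generic sketch of the same keep-alive pattern, with plain Python objects standing in for batches:

```python
# Generic keep-alive ring: pin the two most recent batches so their buffers
# are not freed while asynchronously launched work may still be reading them.
batch_ring = [None] * 2  # mirrors batch_lists in the diff
batch_pt = 0


def hold_batch(batch):
    """Keep a reference so the batch outlives the async work launched for it."""
    global batch_pt
    batch_ring[batch_pt % 2] = batch
    batch_pt += 1


for step in range(5):
    batch = {"input_ids": list(range(step, step + 4))}  # stand-in for a batch object
    hold_batch(batch)
    # ... asynchronous GPU work reading `batch` would be launched here ...

print(batch_ring)  # always holds the two most recent batches
```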
@@ -170,7 +170,6 @@ class CudaGraphRunner:
             self.encoder_lens = None
 
         if self.enable_dp_attention:
-            self.global_num_tokens = [0] * self.tp_size
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.tp_size,
@@ -264,10 +263,10 @@ class CudaGraphRunner:
         mrope_positions = self.mrope_positions[:, :bs]
 
         if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size
+            global_num_tokens = [bs] * self.tp_size
             gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
         else:
-            self.global_num_tokens = None
+            global_num_tokens = None
             gathered_buffer = None
 
         # Attention backend
@@ -296,7 +295,7 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
-            global_num_tokens=self.global_num_tokens,
+            global_num_tokens=global_num_tokens,
             gathered_buffer=gathered_buffer,
         )
         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
@@ -348,8 +347,6 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
...
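These CudaGraphRunner hunks stop rewriting `self.global_num_tokens` at replay time and pass a per-capture local into the captured `ForwardBatch` instead: values that only shape the capture can live in locals, while only the tensor buffers the graph reuses must persist between capture and replay. A generic CUDA-graph sketch of that distinction, assuming a CUDA device; this is not the CudaGraphRunner code.

```python
import torch

# Capture-time metadata (here, `bs`) is a plain local baked into the graph;
# only the tensor buffers the graph reads and writes must stay alive between
# capture and replay, so no host-side list needs rewriting at replay time.
if torch.cuda.is_available():
    static_in = torch.zeros(8, device="cuda")
    static_out = torch.zeros(8, device="cuda")
    bs = 8  # per-capture metadata, analogous to global_num_tokens above

    # Recommended warmup on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in[:bs] * 2.0)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in[:bs] * 2.0)

    static_in.fill_(3.0)  # update the static input buffer in place
    g.replay()            # replays the captured kernels unchanged
    print(static_out)     # tensor of 6.0s
```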
@@ -174,17 +174,17 @@ class ServerArgs:
            self.cuda_graph_max_bs = 4
            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

        # Choose kernel backends
        if not is_flashinfer_available():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"

        # Default kernel backends
        if self.attention_backend is None:
            self.attention_backend = "flashinfer"
        if self.sampling_backend is None:
            self.sampling_backend = "flashinfer"

        # Others
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            self.chunked_prefill_size = self.chunked_prefill_size // 2
@@ -205,9 +205,6 @@ class ServerArgs:
             )
             self.disable_overlap_schedule = True
 
-        if not self.disable_overlap_schedule:
-            self.disable_jump_forward = True
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
...
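For reference, the kernel-backend selection shown in the first ServerArgs hunk boils down to: unset backends default to flashinfer, and triton attention with pytorch sampling are used when flashinfer is not installed. A simplified, hedged restatement; the flat helper below is illustrative and not sglang's API.

```python
# Simplified restatement of the kernel-backend defaults shown above; the
# function name and signature are hypothetical, for illustration only.
def choose_backends(attention_backend, sampling_backend, flashinfer_available):
    if not flashinfer_available:
        # Without flashinfer, fall back to triton attention and pytorch sampling.
        attention_backend = "triton"
        sampling_backend = "pytorch"
    if attention_backend is None:
        attention_backend = "flashinfer"
    if sampling_backend is None:
        sampling_backend = "flashinfer"
    return attention_backend, sampling_backend


assert choose_backends(None, None, True) == ("flashinfer", "flashinfer")
assert choose_backends(None, None, False) == ("triton", "pytorch")
assert choose_backends("triton", None, True) == ("triton", "flashinfer")
```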
@@ -2,3 +2,4 @@
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}')