Unverified commit 722530fa, authored by Lianmin Zheng and committed by GitHub

Enable overlap scheduler by default for the triton attention backend (#2105)

parent 56a347f7
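Note: the overlap scheduler lets the CPU-side scheduler keep building the next batch while the GPU is still running the current one; results flow back through "future" token-id placeholders (the future_token_ids_ct handled in the worker below) that are resolved once the GPU finishes. The sketch below is a rough illustration of that producer/consumer split, not SGLang's implementation; launch_gpu_forward is a hypothetical stand-in for the model worker's forward call.

# Rough sketch of the overlap idea (assumed simplification, not SGLang code):
# a worker thread drains batches from a queue and launches GPU work, while
# the scheduler thread immediately moves on to building the next batch.
import queue
import threading

def launch_gpu_forward(batch):
    # Hypothetical stand-in: enqueue the model forward for `batch` on the GPU
    # and return without waiting for the result.
    pass

def forward_thread(input_queue: queue.Queue):
    while True:
        batch = input_queue.get()
        if batch is None:          # sentinel: shut the worker down
            break
        launch_gpu_forward(batch)

input_queue: queue.Queue = queue.Queue()
worker = threading.Thread(target=forward_thread, args=(input_queue,), daemon=True)
worker.start()

for batch in ("batch0", "batch1"):  # the scheduler keeps running ahead of the GPU
    input_queue.put(batch)
input_queue.put(None)
worker.join()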
@@ -53,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
         start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
         start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
-        total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+        total_num_tokens = forward_batch.seq_lens_sum
         attn_logits = torch.empty(
             (self.num_head, total_num_tokens),
             dtype=self.reduce_dtype,
...
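Note: the only change in this hunk swaps torch.sum(...).item() for the precomputed forward_batch.seq_lens_sum. Presumably this matters because .item() forces a device-to-host copy and therefore a GPU sync, which would stall the CPU thread that the overlap scheduler is trying to keep running ahead. A minimal sketch of the difference, with made-up sizes:

# Minimal sketch (made-up sizes, not SGLang code) of why a host-side running
# sum is preferred over reducing on the GPU and calling .item() in the hot path.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
seq_lens = torch.tensor([5, 3, 7], device=device)

# Blocking pattern: .item() waits for every kernel queued so far to finish
# before the Python thread can continue.
total_blocking = int(torch.sum(seq_lens).item())

# Overlap-friendly pattern: the total is tracked as a plain Python int when
# the batch is assembled, so no synchronization is needed here.
seq_lens_sum = 5 + 3 + 7
attn_logits = torch.empty((8, seq_lens_sum), dtype=torch.float32, device=device)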
@@ -170,18 +170,9 @@ class Scheduler:
         if not self.is_generation:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")

-        if (
-            server_args.attention_backend == "triton"
-            or server_args.enable_double_sparsity
-            or (
-                self.model_config.attention_arch == AttentionArch.MLA
-                and not self.server_args.disable_mla
-            )
-        ):
-            self.enable_overlap = False
-            logger.info(
-                "Overlap scheduler is disabled if using triton attention backend."
-            )
+        if self.enable_overlap:
+            self.disable_jump_forward = True

         # Launch a tensor parallel worker
         if self.enable_overlap:
...
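Note: with the backend-specific opt-outs gone, overlap is no longer force-disabled for the triton backend, double sparsity, or MLA. Instead the scheduler turns off jump-forward decoding whenever overlap is active; jump-forward edits a request mid-decode and so presumably needs the real token ids synchronously, which the overlap scheduler only provides later. A condensed sketch of the resulting flag logic, with assumed simplified inputs:

# Condensed sketch (assumed simplification, not the real Scheduler) of the
# flag logic after this change; SGLang reads these from ServerArgs and the
# model config instead.
class SchedulerFlags:
    def __init__(
        self,
        is_generation: bool,
        disable_overlap_schedule: bool,
        disable_jump_forward: bool,
    ):
        self.enable_overlap = is_generation and not disable_overlap_schedule
        # Jump-forward is switched off whenever the overlap scheduler stays on.
        self.disable_jump_forward = disable_jump_forward or self.enable_overlap

flags = SchedulerFlags(
    is_generation=True, disable_overlap_schedule=False, disable_jump_forward=False
)
assert flags.enable_overlap and flags.disable_jump_forward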
@@ -94,10 +94,21 @@ class TpModelWorkerClient:
     @torch.no_grad()
     def forward_thread_func_(self):
+        batch_pt = 0
+        batch_lists = [None] * 2
+
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
             if not model_worker_batch:
                 break
+
+            # Keep a reference of model_worker_batch by storing it into a list.
+            # Otherwise, the tensor members of model_worker_batch will be released
+            # by pytorch and cause CUDA illegal memory access errors.
+            batch_lists[batch_pt % 2] = model_worker_batch
+            batch_pt += 1
+
+            # Create event
             self.launch_done = threading.Event()
             copy_done = torch.cuda.Event()
...
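Note: the new comment spells out the reason for the two-slot list. The worker launches kernels asynchronously, so if the Python reference to the batch were dropped as soon as the loop iterates, PyTorch could free tensors that in-flight kernels still read, producing the CUDA illegal memory access errors mentioned above. A generic version of the pattern, with hypothetical names:

# Generic version of the "keep the batch alive" pattern (hypothetical names,
# not SGLang code). Holding the last two batches is enough here because by
# the time a slot is overwritten, the GPU work launched two iterations ago is
# assumed to have completed.
class BatchKeeper:
    def __init__(self, depth: int = 2):
        self._slots = [None] * depth
        self._pt = 0

    def hold(self, batch) -> None:
        # Overwrite the oldest slot; the previous occupant becomes collectable.
        self._slots[self._pt % len(self._slots)] = batch
        self._pt += 1

keeper = BatchKeeper()
for batch in ("b0", "b1", "b2"):
    keeper.hold(batch)   # "b0" is released only after "b2" is stored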
@@ -170,7 +170,6 @@ class CudaGraphRunner:
         self.encoder_lens = None

         if self.enable_dp_attention:
-            self.global_num_tokens = [0] * self.tp_size
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.tp_size,
@@ -264,10 +263,10 @@ class CudaGraphRunner:
             mrope_positions = self.mrope_positions[:, :bs]

         if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size
+            global_num_tokens = [bs] * self.tp_size
             gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
         else:
-            self.global_num_tokens = None
+            global_num_tokens = None
             gathered_buffer = None

         # Attention backend
@@ -296,7 +295,7 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=clamp_position(seq_lens),
             mrope_positions=mrope_positions,
-            global_num_tokens=self.global_num_tokens,
+            global_num_tokens=global_num_tokens,
             gathered_buffer=gathered_buffer,
         )
         logits_output = forward(input_ids, forward_batch.positions, forward_batch)
@@ -348,8 +347,6 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
-            self.global_num_tokens[:] = [bs] * self.tp_size

         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
...
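Note: across these hunks global_num_tokens stops being a persistent attribute that replay rewrites and becomes a local computed once per capture. That is consistent with how CUDA graphs treat Python-level values: anything read while the graph is recorded is baked in, so only the static input tensors need to be refreshed before replay. A small illustration (requires a CUDA device; not SGLang code):

# Illustration (requires CUDA): Python values read during capture are frozen
# into the graph, so mutating the list before replay changes nothing; replay
# inputs are fed by copying into the static tensors instead.
import torch

static_x = torch.zeros(8, device="cuda")
num_tokens = [8]                       # Python-level metadata

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x * num_tokens[0]   # num_tokens[0] is read once, here

num_tokens[0] = 16                     # no effect on the captured graph
static_x.copy_(torch.ones(8, device="cuda"))
g.replay()                             # static_y now holds ones * 8, not * 16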
@@ -174,17 +174,17 @@ class ServerArgs:
             self.cuda_graph_max_bs = 4
             logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")

-        # Choose kernel backends
         if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"

+        # Default kernel backends
         if self.attention_backend is None:
             self.attention_backend = "flashinfer"
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"

+        # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
@@ -205,9 +205,6 @@ class ServerArgs:
             )
             self.disable_overlap_schedule = True

-        if not self.disable_overlap_schedule:
-            self.disable_jump_forward = True
-
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
...
@@ -2,3 +2,4 @@
 kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
 kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}')