Unverified commit 62832bb2, authored by Ke Bao, committed by GitHub

Support cuda graph for DP attention (#2061)

parent 11f881d1
@@ -455,6 +455,7 @@ class ScheduleBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
+    can_run_dp_cuda_graph: bool = False
 
     # For processing logprobs
     return_logprob: bool = False
@@ -891,6 +892,13 @@ class ScheduleBatch:
         self.seq_lens = torch.empty(0, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens_sum = 0
         self.extend_num_tokens = 0
 
     def prepare_for_decode(self, enable_overlap: bool = False):
@@ -1032,6 +1040,7 @@ class ScheduleBatch:
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
+            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]]
+    can_run_dp_cuda_graph: bool
 
     # For extend
     extend_num_tokens: Optional[int]
......
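Note on the new flag: `can_run_dp_cuda_graph` is decided once per step by the scheduler and then carried unchanged from `ScheduleBatch` through `ModelWorkerBatch` into `ForwardBatch`, where the CUDA graph runner reads it. A minimal sketch of that hand-off (class and field names follow the diff; everything else is stripped down and illustrative):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ScheduleBatch:
    # Set once per scheduling step by the scheduler (see scheduler.py below).
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    def get_model_worker_batch(self) -> "ModelWorkerBatch":
        # The DP-attention fields are copied verbatim when handing off to the worker.
        return ModelWorkerBatch(
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
        )


@dataclass
class ModelWorkerBatch:
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool


@dataclass
class ForwardBatch:
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    @classmethod
    def init_new(cls, batch: ModelWorkerBatch) -> "ForwardBatch":
        # CudaGraphRunner.can_run() later reads this flag to decide whether to replay.
        return cls(
            global_num_tokens=batch.global_num_tokens,
            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
        )
```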
@@ -337,7 +337,7 @@ class Scheduler:
             kill_parent_process()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_normal(self):
         """A normal blocking scheduler loop."""
         self.last_batch = None
@@ -375,7 +375,7 @@ class Scheduler:
             self.last_batch = batch
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def event_loop_overlap(self):
         """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()
@@ -411,16 +411,12 @@ class Scheduler:
         else:
             num_tokens = local_batch.extend_num_tokens
 
-        local_num_tokens = torch.tensor(
-            num_tokens, dtype=torch.int64, device=self.device
-        )
-        global_num_tokens = torch.empty(
-            self.tp_size, dtype=torch.int64, device=self.device
-        )
+        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
         torch.distributed.all_gather_into_tensor(
             global_num_tokens,
             local_num_tokens,
-            group=self.tp_worker.get_tp_device_group(),
+            group=self.tp_cpu_group,
         )
 
         if local_batch is None and global_num_tokens.max().item() > 0:
@@ -429,6 +425,24 @@ class Scheduler:
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens.tolist()
 
+            # Check forward mode for cuda graph
+            if not self.server_args.disable_cuda_graph:
+                forward_mode_state = torch.tensor(
+                    (
+                        1
+                        if local_batch.forward_mode.is_decode()
+                        or local_batch.forward_mode.is_idle()
+                        else 0
+                    ),
+                    dtype=torch.int32,
+                )
+                torch.distributed.all_reduce(
+                    forward_mode_state,
+                    op=torch.distributed.ReduceOp.MIN,
+                    group=self.tp_cpu_group,
+                )
+                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+
         return local_batch
 
     def get_idle_batch(self):
......
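The scheduling change in a nutshell: both the token-count gather and the new forward-mode vote run over the TP CPU group (gloo), so the control path needs no GPU collective, and the MIN reduction means a single rank still doing prefill vetoes CUDA-graph replay for every rank. A standalone sketch of that consensus step (the helper name is mine; it assumes `torch.distributed` is initialized with a CPU-capable group):

```python
import torch
import torch.distributed as dist


def dp_attention_sync(
    num_tokens: int, is_decode_or_idle: bool, tp_cpu_group, tp_size: int
):
    """Sketch of the per-step synchronization added by this commit.

    1) All-gather each rank's token count so every rank can size the
       DP-attention gather buffers for the whole TP group.
    2) Vote on CUDA-graph eligibility with a MIN all-reduce: each rank
       contributes 1 only if its local batch is decode or idle, so one rank
       still running prefill (extend) forces all ranks to fall back to eager
       mode and the collective shapes stay identical across ranks.
    """
    local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
    global_num_tokens = torch.empty(tp_size, dtype=torch.int64)
    dist.all_gather_into_tensor(
        global_num_tokens, local_num_tokens, group=tp_cpu_group
    )

    vote = torch.tensor(1 if is_decode_or_idle else 0, dtype=torch.int32)
    dist.all_reduce(vote, op=dist.ReduceOp.MIN, group=tp_cpu_group)

    return global_num_tokens.tolist(), vote.item() == 1
```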
@@ -128,9 +128,6 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group
 
-    def get_tp_device_group(self):
-        return self.model_runner.tp_group.device_group
-
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
......
@@ -83,9 +83,6 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()
 
-    def get_tp_device_group(self):
-        return self.worker.get_tp_device_group()
-
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -96,7 +93,7 @@ class TpModelWorkerClient:
         with torch.cuda.stream(self.forward_stream):
             self.forward_thread_func_()
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def forward_thread_func_(self):
         while True:
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
......
@@ -111,6 +111,8 @@ class CudaGraphRunner:
         self.use_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
+        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
+        self.tp_size = self.model_runner.tp_size
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
@@ -165,6 +167,16 @@ class CudaGraphRunner:
         else:
             self.encoder_lens = None
 
+        if self.enable_dp_attention:
+            self.global_num_tokens = [0] * self.tp_size
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_bs * self.tp_size,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+
         # Capture
         try:
             with self.model_capture_mode():
@@ -190,6 +202,16 @@ class CudaGraphRunner:
         self.model_runner.model.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
+        if self.enable_dp_attention:
+            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
+                forward_batch.global_num_tokens
+            )
+            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
+                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                if self.disable_padding
+                else max_num_tokens <= self.max_bs
+            )
+        else:
             is_bs_supported = (
                 forward_batch.batch_size in self.graphs
                 if self.disable_padding
@@ -239,6 +261,13 @@ class CudaGraphRunner:
         seq_lens_sum = seq_lens.sum().item()
         mrope_positions = self.mrope_positions[:, :bs]
 
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
+            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+        else:
+            self.global_num_tokens = None
+            gathered_buffer = None
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -265,6 +294,8 @@ class CudaGraphRunner:
                 top_logprobs_nums=[0] * bs,
                 positions=clamp_position(seq_lens),
                 mrope_positions=mrope_positions,
+                global_num_tokens=self.global_num_tokens,
+                gathered_buffer=gathered_buffer,
             )
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits
@@ -295,6 +326,11 @@ class CudaGraphRunner:
         raw_bs = forward_batch.batch_size
 
         # Pad
+        if self.enable_dp_attention:
+            index = bisect.bisect_left(
+                self.capture_bs, max(forward_batch.global_num_tokens)
+            )
+        else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
@@ -310,6 +346,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens[:] = [bs] * self.tp_size
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
......
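Under DP attention every rank must replay the same captured graph, so `can_run` gates on the scheduler-agreed flag plus the maximum per-rank token count, and `replay` pads to the captured batch size chosen by that maximum rather than by the local batch size. A small self-contained sketch of the size selection (the helper name and the capture sizes are illustrative, not from the code):

```python
import bisect
from typing import List


def pick_replay_bs(capture_bs: List[int], global_num_tokens: List[int]) -> int:
    # capture_bs is sorted ascending. bisect_left finds the smallest captured
    # batch size that can hold the largest per-rank token count, which is the
    # size every rank pads to and replays so the graphs stay in lockstep.
    index = bisect.bisect_left(capture_bs, max(global_num_tokens))
    return capture_bs[index]


# Ranks decode 3, 7 and 5 tokens this step: all of them replay the bs=8 graph,
# and the DP all-gather buffer slice for that graph covers 8 * tp_size rows.
assert pick_replay_bs([1, 2, 4, 8, 16], [3, 7, 5]) == 8
```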
@@ -138,6 +138,7 @@ class ForwardBatch:
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
     gathered_buffer: Optional[torch.Tensor] = None
+    can_run_dp_cuda_graph: bool = False
 
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
@@ -221,6 +222,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             global_num_tokens=batch.global_num_tokens,
+            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
         )
......
@@ -592,6 +592,9 @@ class ModelRunner:
         )
 
     def forward_idle(self, forward_batch: ForwardBatch):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
......
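Why `forward_idle` now consults the graph runner at all (my reading of the change, not stated in the diff): a rank with no local requests still owns a slot in the DP-attention all-gather, so its IDLE batch is a real forward pass, and replaying a captured graph keeps it on the same collective schedule as the ranks that replay graphs. A free-function sketch of the dispatch:

```python
from typing import Optional


def forward_idle(model, cuda_graph_runner: Optional[object], forward_batch):
    # Sketch of the new ModelRunner.forward_idle dispatch in free-function form.
    # If every rank agreed this step is decode/idle, the idle rank replays a
    # captured graph too; otherwise it falls back to an eager forward pass.
    if cuda_graph_runner is not None and cuda_graph_runner.can_run(forward_batch):
        return cuda_graph_runner.replay(forward_batch)
    return model.forward(
        forward_batch.input_ids, forward_batch.positions, forward_batch
    )
```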
@@ -191,11 +191,12 @@ class ServerArgs:
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.disable_cuda_graph = True
+            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
-                "The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
+                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
+                "Data parallel size is adjusted to be the same as tensor parallel size."
             )
 
         if self.enable_overlap_schedule:
......
@@ -31,7 +31,7 @@ from transformers import AutoModelForCausalLM
 from sglang.srt.hf_transformers_utils import get_tokenizer
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def normal_text(args):
     t = get_tokenizer(args.model_path, trust_remote_code=True)
     m = AutoModelForCausalLM.from_pretrained(
@@ -69,7 +69,7 @@ def normal_text(args):
     print(output_str)
 
 
-@torch.inference_mode()
+@torch.no_grad()
 def synthetic_tokens(args):
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
......
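On the repeated `@torch.inference_mode()` to `@torch.no_grad()` swaps throughout this commit: tensors created under inference mode become inference tensors, which PyTorch refuses to update in place outside inference mode; that restriction is awkward for buffers that are later mutated in place (for example CUDA-graph input buffers), while `no_grad` tensors carry no such restriction. That rationale is my inference from the change, not stated in the commit. A tiny reproduction of the difference:

```python
import torch


def make_buffer(use_inference_mode: bool) -> torch.Tensor:
    # Create a buffer under one of the two grad-disabling contexts.
    ctx = torch.inference_mode() if use_inference_mode else torch.no_grad()
    with ctx:
        return torch.zeros(4)


buf = make_buffer(use_inference_mode=True)
try:
    buf.add_(1)  # in-place update outside inference mode
except RuntimeError as err:
    print("inference tensor:", err)  # "Inplace update to inference tensor ..."

buf = make_buffer(use_inference_mode=False)
buf.add_(1)  # fine: a no_grad tensor is an ordinary tensor
print(buf)
```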