".github/vscode:/vscode.git/clone" did not exist on "3f14b88db5b98552b9dc637a86ea3998cb4b4c16"
Unverified commit 62832bb2, authored by Ke Bao and committed by GitHub

Support cuda graph for DP attention (#2061)

parent 11f881d1
......@@ -455,6 +455,7 @@ class ScheduleBatch:
# For DP attention
global_num_tokens: Optional[List[int]] = None
can_run_dp_cuda_graph: bool = False
# For processing logprobs
return_logprob: bool = False
......@@ -891,6 +892,13 @@ class ScheduleBatch:
self.seq_lens = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
self.device, non_blocking=True
)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
def prepare_for_decode(self, enable_overlap: bool = False):
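For an idle rank, the batch now also clears out_cache_loc and req_pool_indices, so every field downstream code touches exists with length zero. A minimal sketch of that shape (the make_idle_batch helper and the dict layout are illustrative, not SGLang's API):

import torch


def make_idle_batch(device: str = "cpu") -> dict:
    """Zero-length placeholders so an idle batch has the same fields as a real one."""
    return {
        "seq_lens": torch.empty(0, dtype=torch.int32, device=device),
        "out_cache_loc": torch.empty(0, dtype=torch.int32, device=device),
        "req_pool_indices": torch.empty(0, dtype=torch.int32, device=device),
        "seq_lens_sum": 0,
        "extend_num_tokens": 0,
    }


if __name__ == "__main__":
    batch = make_idle_batch()
    print({k: (tuple(v.shape) if torch.is_tensor(v) else v) for k, v in batch.items()})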
......@@ -1032,6 +1040,7 @@ class ScheduleBatch:
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
global_num_tokens=self.global_num_tokens,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
......@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:
# For DP attention
global_num_tokens: Optional[List[int]]
can_run_dp_cuda_graph: bool
# For extend
extend_num_tokens: Optional[int]
......
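The new can_run_dp_cuda_graph flag is threaded unchanged from ScheduleBatch through ModelWorkerBatch into ForwardBatch, where the CUDA graph runner reads it. A rough sketch of that hand-off with stand-in dataclasses (field sets trimmed to the DP-attention parts; the class names here are placeholders, not SGLang's):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ScheduleBatchLite:
    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False


@dataclass
class ModelWorkerBatchLite:
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool


def get_model_worker_batch(b: ScheduleBatchLite) -> ModelWorkerBatchLite:
    # The scheduler-side flag is copied verbatim into the worker batch,
    # and later into ForwardBatch, so the CUDA graph runner can see it.
    return ModelWorkerBatchLite(
        global_num_tokens=b.global_num_tokens,
        can_run_dp_cuda_graph=b.can_run_dp_cuda_graph,
    )


if __name__ == "__main__":
    sb = ScheduleBatchLite(global_num_tokens=[8, 8, 8, 8], can_run_dp_cuda_graph=True)
    print(get_model_worker_batch(sb))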
......@@ -337,7 +337,7 @@ class Scheduler:
kill_parent_process()
-    @torch.inference_mode()
+    @torch.no_grad()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
self.last_batch = None
......@@ -375,7 +375,7 @@ class Scheduler:
self.last_batch = batch
-    @torch.inference_mode()
+    @torch.no_grad()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
......@@ -411,16 +411,12 @@ class Scheduler:
else:
num_tokens = local_batch.extend_num_tokens
-    local_num_tokens = torch.tensor(
-        num_tokens, dtype=torch.int64, device=self.device
-    )
-    global_num_tokens = torch.empty(
-        self.tp_size, dtype=torch.int64, device=self.device
-    )
+    local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
+    global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
-        group=self.tp_worker.get_tp_device_group(),
+        group=self.tp_cpu_group,
)
if local_batch is None and global_num_tokens.max().item() > 0:
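The size exchange now uses plain CPU tensors on the CPU (gloo) communicator instead of device tensors on the TP device group, so agreeing on token counts does not touch the GPU stream. A runnable single-process sketch of the same collective, assuming a recent PyTorch whose gloo backend supports all_gather_into_tensor (the default world group stands in for tp_cpu_group):

import os

import torch
import torch.distributed as dist

# Single-process process group so the sketch runs without a launcher.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

tp_size = dist.get_world_size()
num_tokens = 7  # this rank's token count (decode batch size or extend_num_tokens)

# CPU tensors + gloo group: no GPU work is needed just to agree on sizes.
local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(tp_size, dtype=torch.int64)
dist.all_gather_into_tensor(global_num_tokens, local_num_tokens)

print(global_num_tokens.tolist())  # every rank ends up with every rank's count
dist.destroy_process_group()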
......@@ -429,6 +425,24 @@ class Scheduler:
if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()
# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
return local_batch
def get_idle_batch(self):
......
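The forward-mode check is effectively a logical AND across TP ranks: each rank contributes 1 if its local batch is decode or idle, and ReduceOp.MIN leaves 1 only when every rank did. The same semantics simulated locally, without a process group (the helper name and mode strings are illustrative):

from typing import List


def dp_cuda_graph_vote(forward_modes: List[str]) -> bool:
    """Mimic the MIN all_reduce: every rank must be in decode or idle mode."""
    votes = [1 if mode in ("decode", "idle") else 0 for mode in forward_modes]
    return min(votes) == 1  # ReduceOp.MIN over the per-rank votes


if __name__ == "__main__":
    # All ranks decoding (or idle fillers) -> the DP CUDA graph can be replayed.
    print(dp_cuda_graph_vote(["decode", "idle", "decode", "decode"]))  # True
    # One rank still extending a prefill -> fall back to the eager path.
    print(dp_cuda_graph_vote(["decode", "extend", "decode", "decode"]))  # False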
......@@ -128,9 +128,6 @@ class TpModelWorker:
def get_tp_cpu_group(self):
return self.model_runner.tp_group.cpu_group
-    def get_tp_device_group(self):
-        return self.model_runner.tp_group.device_group
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
......
......@@ -83,9 +83,6 @@ class TpModelWorkerClient:
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()
-    def get_tp_device_group(self):
-        return self.worker.get_tp_device_group()
def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
......@@ -96,7 +93,7 @@ class TpModelWorkerClient:
with torch.cuda.stream(self.forward_stream):
self.forward_thread_func_()
-    @torch.inference_mode()
+    @torch.no_grad()
def forward_thread_func_(self):
while True:
model_worker_batch, future_token_ids_ct = self.input_queue.get()
......
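The swaps from @torch.inference_mode() to @torch.no_grad() are not cosmetic: tensors created under inference mode are "inference tensors", and PyTorch restricts how they can be used outside that mode, which is presumably what these loops want to avoid now that outputs are kept around and reused across CUDA graph replays and the overlap thread. A small standalone illustration of the restriction (not SGLang code):

import torch


def make_token_ids(use_inference_mode: bool) -> torch.Tensor:
    ctx = torch.inference_mode() if use_inference_mode else torch.no_grad()
    with ctx:
        return torch.zeros(4, dtype=torch.int64)


ids = make_token_ids(use_inference_mode=False)
ids.add_(1)  # fine: a tensor made under no_grad is an ordinary tensor afterwards

ids = make_token_ids(use_inference_mode=True)
try:
    ids.add_(1)  # in-place update of an inference tensor outside inference mode
except RuntimeError as e:
    print("inference_mode restriction:", e)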
......@@ -111,6 +111,8 @@ class CudaGraphRunner:
self.use_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
self.tp_size = self.model_runner.tp_size
# Batch sizes to capture
if model_runner.server_args.disable_cuda_graph_padding:
......@@ -165,6 +167,16 @@ class CudaGraphRunner:
else:
self.encoder_lens = None
if self.enable_dp_attention:
self.global_num_tokens = [0] * self.tp_size
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.tp_size,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
# Capture
try:
with self.model_capture_mode():
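With DP attention, every rank's hidden states are gathered before attention, so the runner pre-allocates one static buffer sized for the worst case, max_bs tokens from each of tp_size ranks, and slices it per captured graph. A back-of-the-envelope sketch of that allocation (all numbers are illustrative, not SGLang defaults):

import torch

# Illustrative values, not SGLang defaults.
max_bs = 96          # largest decode batch size captured per rank
tp_size = 8          # ranks participating in DP attention
hidden_size = 4096
dtype = torch.bfloat16

# One static buffer reused by every captured graph; CUDA graphs need fixed
# addresses, so it is sized for the largest gather and sliced down to
# bs * tp_size rows at capture/replay time.
gathered_buffer = torch.zeros(max_bs * tp_size, hidden_size, dtype=dtype)

bytes_needed = gathered_buffer.numel() * gathered_buffer.element_size()
print(f"{gathered_buffer.shape=} ~ {bytes_needed / 2**20:.1f} MiB")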
......@@ -190,6 +202,16 @@ class CudaGraphRunner:
self.model_runner.model.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
forward_batch.global_num_tokens
)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
if self.disable_padding
else max_num_tokens <= self.max_bs
)
else:
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
......@@ -239,6 +261,13 @@ class CudaGraphRunner:
seq_lens_sum = seq_lens.sum().item()
mrope_positions = self.mrope_positions[:, :bs]
if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
else:
self.global_num_tokens = None
gathered_buffer = None
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
bs,
......@@ -265,6 +294,8 @@ class CudaGraphRunner:
top_logprobs_nums=[0] * bs,
positions=clamp_position(seq_lens),
mrope_positions=mrope_positions,
global_num_tokens=self.global_num_tokens,
gathered_buffer=gathered_buffer,
)
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits
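Each graph is captured as if every rank were running the same padded batch size, so the per-rank token list is simply [bs] * tp_size and the gather buffer is sliced to bs * tp_size rows. A toy version of that bookkeeping (the helper and sizes are stand-ins):

import torch


def capture_dp_metadata(bs: int, tp_size: int, gathered_buffer: torch.Tensor):
    """Return the static DP-attention inputs a graph would be captured with."""
    global_num_tokens = [bs] * tp_size      # uniform by construction at capture time
    view = gathered_buffer[: bs * tp_size]  # slice of the persistent buffer
    return global_num_tokens, view


buf = torch.zeros(96 * 4, 8)  # max_bs * tp_size rows, toy hidden size
tokens, view = capture_dp_metadata(bs=16, tp_size=4, gathered_buffer=buf)
print(tokens, tuple(view.shape))  # [16, 16, 16, 16] (64, 8)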
......@@ -295,6 +326,11 @@ class CudaGraphRunner:
raw_bs = forward_batch.batch_size
# Pad
if self.enable_dp_attention:
index = bisect.bisect_left(
self.capture_bs, max(forward_batch.global_num_tokens)
)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
bs = self.capture_bs[index]
if bs != raw_bs:
......@@ -310,6 +346,8 @@ class CudaGraphRunner:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
self.global_num_tokens[:] = [bs] * self.tp_size
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
......
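Replay eligibility and padding both key off the gathered token counts rather than the local batch size: without padding, every rank must already sit on a captured size; with padding, the largest rank only has to fit under max_bs, and bisect then picks the smallest captured size that covers it. A self-contained sketch of both decisions (the capture sizes are made up):

import bisect
from typing import List

capture_bs = [1, 2, 4, 8, 16, 32, 64, 96]  # sizes with captured graphs
graphs = set(capture_bs)                   # stand-in for self.graphs
max_bs = capture_bs[-1]


def can_run_dp(global_num_tokens: List[int], can_run_dp_cuda_graph: bool,
               disable_padding: bool) -> bool:
    lo, hi = min(global_num_tokens), max(global_num_tokens)
    if disable_padding:
        # No padding: every rank must sit exactly on a captured size.
        return can_run_dp_cuda_graph and lo == hi and hi in graphs
    return can_run_dp_cuda_graph and hi <= max_bs


def padded_bs(global_num_tokens: List[int]) -> int:
    # Pad every rank up to the smallest captured size >= the largest rank.
    return capture_bs[bisect.bisect_left(capture_bs, max(global_num_tokens))]


print(can_run_dp([5, 7, 3, 6], True, disable_padding=False))  # True
print(padded_bs([5, 7, 3, 6]))                                # 8
print(can_run_dp([5, 7, 3, 6], True, disable_padding=True))   # False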
......@@ -138,6 +138,7 @@ class ForwardBatch:
# For DP attention
global_num_tokens: Optional[List[int]] = None
gathered_buffer: Optional[torch.Tensor] = None
can_run_dp_cuda_graph: bool = False
def compute_mrope_positions(
self, model_runner: ModelRunner, batch: ModelWorkerBatch
......@@ -221,6 +222,7 @@ class ForwardBatch:
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
global_num_tokens=batch.global_num_tokens,
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
lora_paths=batch.lora_paths,
sampling_info=batch.sampling_info,
)
......
......@@ -592,6 +592,9 @@ class ModelRunner:
)
def forward_idle(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
......
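An idle rank now also tries the CUDA graph first, so all DP ranks stay on the same replay path during decode. A rough sketch of the dispatch with stand-in objects (FakeCudaGraphRunner and the dict batch are placeholders, not SGLang classes):

class FakeCudaGraphRunner:
    """Stand-in for CudaGraphRunner: decides replay vs. eager per batch."""

    def can_run(self, batch) -> bool:
        return batch.get("can_run_dp_cuda_graph", False)

    def replay(self, batch):
        return f"replayed graph for {batch['forward_mode']} batch"


def forward_idle(cuda_graph_runner, batch):
    # Same shape as ModelRunner.forward_idle: prefer graph replay,
    # otherwise fall back to the eager model forward.
    if cuda_graph_runner and cuda_graph_runner.can_run(batch):
        return cuda_graph_runner.replay(batch)
    return f"eager forward for {batch['forward_mode']} batch"


runner = FakeCudaGraphRunner()
print(forward_idle(runner, {"forward_mode": "idle", "can_run_dp_cuda_graph": True}))
print(forward_idle(runner, {"forward_mode": "idle", "can_run_dp_cuda_graph": False}))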
......@@ -191,11 +191,12 @@ class ServerArgs:
if self.enable_dp_attention:
self.dp_size = self.tp_size
self.chunked_prefill_size = self.chunked_prefill_size // 2
-            self.disable_cuda_graph = True
+            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
self.enable_overlap_schedule = False
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
"The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
"Data parallel size is adjusted to be the same as tensor parallel size."
)
if self.enable_overlap_schedule:
......
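Argument post-processing no longer force-disables CUDA graphs when DP attention is on; it caps the graph batch size instead and keeps the other adjustments. A condensed sketch of the resulting rules (the dataclass and its default values are stand-ins for SGLang's ServerArgs):

from dataclasses import dataclass


@dataclass
class ArgsLite:
    tp_size: int = 8
    dp_size: int = 1
    chunked_prefill_size: int = 8192
    cuda_graph_max_bs: int = 160
    enable_dp_attention: bool = True
    enable_overlap_schedule: bool = False


def adjust_for_dp_attention(args: ArgsLite) -> ArgsLite:
    if args.enable_dp_attention:
        args.dp_size = args.tp_size                               # DP rides on the TP ranks
        args.chunked_prefill_size //= 2                           # smaller chunks for MoE workloads
        args.cuda_graph_max_bs = min(args.cuda_graph_max_bs, 96)  # cap graph batch size
        args.enable_overlap_schedule = False                      # overlap not supported here
    return args


print(adjust_for_dp_attention(ArgsLite()))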
......@@ -31,7 +31,7 @@ from transformers import AutoModelForCausalLM
from sglang.srt.hf_transformers_utils import get_tokenizer
-@torch.inference_mode()
+@torch.no_grad()
def normal_text(args):
t = get_tokenizer(args.model_path, trust_remote_code=True)
m = AutoModelForCausalLM.from_pretrained(
......@@ -69,7 +69,7 @@ def normal_text(args):
print(output_str)
-@torch.inference_mode()
+@torch.no_grad()
def synthetic_tokens(args):
m = AutoModelForCausalLM.from_pretrained(
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
......