Unverified Commit 61970b08 authored by fzyzcjy, committed by GitHub

Let `bench_one_batch` support `enable_dp_attention` (#4058)

parent 76c48a09
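With this change, the single-batch benchmark can be run with data-parallel attention enabled. A typical invocation might look like the sketch below (model path, parallel sizes, and benchmark flags are illustrative and follow the usual sglang server/benchmark arguments; they are not part of this diff):

python -m sglang.bench_one_batch --model-path meta-llama/Llama-3.1-8B-Instruct \
    --tp-size 2 --dp-size 2 --enable-dp-attention \
    --batch-size 8 --input-len 256 --output-len 32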
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -184,6 +185,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
     return input_ids, reqs
@@ -199,6 +201,7 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs
@@ -220,6 +223,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
     return reqs
@@ -238,6 +242,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +254,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +262,20 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
 def correctness_test(
     server_args,
     port_args,
...
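Besides wiring in the DP-attention helper, the prepare_* hunks above also set req.logprob_start_len to the last prompt position. As I read that field, this limits prompt-logprob bookkeeping to the final prefill token. A tiny worked example of the arithmetic (token IDs are made up):

origin_input_ids = [101, 2009, 2003, 1037, 3231]  # a 5-token prompt, illustrative IDs
logprob_start_len = len(origin_input_ids) - 1     # = 4, i.e. only the last prefill position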
@@ -1466,14 +1466,36 @@ class Scheduler(
         self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
     def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_dp_attn_batch_raw(
+            local_batch,
+            dp_size=self.server_args.dp_size,
+            attn_tp_size=self.attn_tp_size,
+            tp_cpu_group=self.tp_cpu_group,
+            get_idle_batch=self.get_idle_batch,
+            disable_cuda_graph=self.server_args.disable_cuda_graph,
+            spec_algorithm=self.spec_algorithm,
+            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+        )
+
+    @staticmethod
+    def prepare_dp_attn_batch_raw(
+        local_batch: ScheduleBatch,
+        dp_size,
+        attn_tp_size: int,
+        tp_cpu_group,
+        get_idle_batch,
+        disable_cuda_graph: bool,
+        spec_algorithm,
+        speculative_num_draft_tokens,
+    ):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
             global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
-                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
+                num_tokens = num_tokens * speculative_num_draft_tokens
             global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
@@ -1492,7 +1514,7 @@ class Scheduler(
         else:
             can_cuda_graph = 0
-        if not self.spec_algorithm.is_none():
+        if not spec_algorithm.is_none():
             # TODO(sang): Support cuda graph when idle batch is there.
             if local_batch is None or local_batch.forward_mode.is_idle():
                 can_cuda_graph = 0
@@ -1510,13 +1532,13 @@
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 4),
+            (dp_size, attn_tp_size, 4),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=self.tp_cpu_group,
+            group=tp_cpu_group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
@@ -1524,14 +1546,14 @@
         is_extend_in_batch = global_info[:, 0, 3].tolist()
         if local_batch is None and max(global_num_tokens) > 0:
-            local_batch = self.get_idle_batch()
+            local_batch = get_idle_batch()
         if local_batch is not None:
             local_batch.global_num_tokens = global_num_tokens
             local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
             # Check forward mode for cuda graph
-            if not self.server_args.disable_cuda_graph:
+            if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph
         return local_batch, any(is_extend_in_batch)
...
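For reference, the collective at the heart of the refactored Scheduler.prepare_dp_attn_batch_raw packs four counters per rank and reads them back per DP group. The standalone sketch below reproduces that (dp_size, attn_tp_size, 4) layout on a single process with the gloo backend; the values, port, and world size are illustrative and not taken from the patch:

import os
import torch
import torch.distributed as dist

# Single-process stand-in for the scheduler's TP CPU group (gloo backend).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

dp_size, attn_tp_size = 1, 1

# Per-rank payload: (num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch).
local_info = torch.tensor([8, 1, 8, 0], dtype=torch.int64)
global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)

# Same collective as in the patch: every rank contributes four int64 counters.
dist.all_gather_into_tensor(global_info.flatten(), local_info, group=dist.group.WORLD)

global_num_tokens = global_info[:, 0, 0].tolist()    # one token count per DP rank
can_cuda_graph = min(global_info[:, 0, 1].tolist())  # CUDA graphs only if every rank can use them
is_extend_in_batch = global_info[:, 0, 3].tolist()
print(global_num_tokens, can_cuda_graph, any(is_extend_in_batch))

dist.destroy_process_group()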