Unverified commit fba8eccd authored by Lianmin Zheng, committed by GitHub

Log if cuda graph is used & extend cuda graph capture to cuda-graph-max-bs (#6201)


Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent 7d3a3d45
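The headline changes: each decode log line now reports whether the batch actually replayed a captured CUDA graph, `get_batch_sizes_to_capture` extends the capture list toward `--cuda-graph-max-bs` when that flag exceeds the largest default size, and `/get_server_info` now nests per-DP-rank scheduler state under an `internal_states` key instead of merging it into the top-level dict. A minimal client-side sketch of the new response shape (the base URL is an assumption; the field accesses mirror the updated benchmark scripts below):

import requests

# Sketch only (not part of this commit): read the reshaped /get_server_info response.
base_url = "http://127.0.0.1:30000"  # assumed local server address
server_info = requests.get(base_url + "/get_server_info").json()

internal = server_info["internal_states"][0]  # state reported by DP rank 0
last_gen_throughput = internal["last_gen_throughput"]  # tok/s of the last decode batch
accept_length = internal.get("avg_spec_accept_length", None)  # None unless speculative decoding is on

print(f"last gen throughput: {last_gen_throughput:.2f} tok/s, accept length: {accept_length}")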
@@ -259,7 +259,9 @@ def throughput_test_once(
         measurement_results["total_input_tokens"]
         + measurement_results["total_output_tokens"]
     ) / latency
-    measurement_results["last_gen_throughput"] = server_info["last_gen_throughput"]
+    measurement_results["last_gen_throughput"] = server_info["internal_states"][0][
+        "last_gen_throughput"
+    ]
     return measurement_results
...
@@ -246,7 +246,7 @@ def extend(reqs, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch
@@ -258,7 +258,7 @@ def decode(input_token_ids, batch, model_runner):
     _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
-    logits_output = model_runner.forward(forward_batch)
+    logits_output, _ = model_runner.forward(forward_batch)
     next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
...
@@ -25,6 +25,7 @@ import requests
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import is_in_ci, write_github_step_summary
 
 
 @dataclasses.dataclass
@@ -33,9 +34,13 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    temperature: float = 0.0
+    return_logprob: bool = False
+    input_len_step_percentage: float = 0.0
     result_filename: str = "result.jsonl"
     base_url: str = ""
     skip_warmup: bool = False
+    show_report: bool = False
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -49,11 +54,19 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument("--return-logprob", action="store_true")
+        parser.add_argument(
+            "--input-len-step-percentage",
+            type=float,
+            default=BenchArgs.input_len_step_percentage,
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
         parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
         parser.add_argument("--skip-warmup", action="store_true")
+        parser.add_argument("--show-report", action="store_true")
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -99,36 +112,89 @@ def run_one_case(
     batch_size: int,
     input_len: int,
     output_len: int,
+    temperature: float,
+    return_logprob: bool,
+    input_len_step_percentage: float,
     run_name: str,
     result_filename: str,
 ):
+    requests.post(url + "/flush_cache")
+    input_lens = [
+        int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
+        for i in range(batch_size)
+    ]
     input_ids = [
-        [int(x) for x in np.random.randint(0, high=16384, size=(input_len,))]
-        for _ in range(batch_size)
+        [int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
+        for i in range(batch_size)
     ]
+
+    use_structured_outputs = False
+    if use_structured_outputs:
+        texts = []
+        for _ in range(batch_size):
+            texts.append(
+                "Human: What is the capital city of france? can you give as many trivial information as possible about that city? answer in json.\n"
+                * 50
+                + "Assistant:"
+            )
+        json_schema = "$$ANY$$"
+    else:
+        json_schema = None
 
     tic = time.time()
     response = requests.post(
         url + "/generate",
         json={
+            # "text": texts,
             "input_ids": input_ids,
             "sampling_params": {
-                "temperature": 0,
+                "temperature": temperature,
                 "max_new_tokens": output_len,
                 "ignore_eos": True,
+                "json_schema": json_schema,
             },
+            "return_logprob": return_logprob,
+            "stream": True,
        },
+        stream=True,
    )
-    latency = time.time() - tic
-    _ = response.json()
-    output_throughput = batch_size * output_len / latency
+
+    # The TTFT of the last request in the batch
+    ttft = 0.0
+    for chunk in response.iter_lines(decode_unicode=False):
+        chunk = chunk.decode("utf-8")
+        if chunk and chunk.startswith("data:"):
+            if chunk == "data: [DONE]":
+                break
+            data = json.loads(chunk[5:].strip("\n"))
+            if "error" in data:
+                raise RuntimeError(f"Request has failed. {data}.")
+            assert (
+                data["meta_info"]["finish_reason"] is None
+                or data["meta_info"]["finish_reason"]["type"] == "length"
+            )
+            if data["meta_info"]["completion_tokens"] == 1:
+                ttft = time.time() - tic
+
+    latency = time.time() - tic
+    input_throughput = batch_size * input_len / ttft
+    output_throughput = batch_size * output_len / (latency - ttft)
     overall_throughput = batch_size * (input_len + output_len) / latency
+
+    server_info = requests.get(url + "/get_server_info").json()
+    acc_length = server_info["internal_states"][0].get("avg_spec_accept_length", None)
+    last_gen_throughput = server_info["internal_states"][0]["last_gen_throughput"]
 
     print(f"batch size: {batch_size}")
+    print(f"input_len: {input_len}")
+    print(f"output_len: {output_len}")
     print(f"latency: {latency:.2f} s")
-    print(f"output throughput: {output_throughput:.2f} token/s")
-    print(f"(input + output) throughput: {overall_throughput:.2f} token/s")
+    print(f"ttft: {ttft:.2f} s")
+    print(f"Last generation throughput: {last_gen_throughput:.2f} tok/s")
+    print(f"Input throughput: {input_throughput:.2f} tok/s")
+    if output_len != 1:
+        print(f"output throughput: {output_throughput:.2f} tok/s")
 
     if result_filename:
         with open(result_filename, "a") as fout:
@@ -140,9 +206,21 @@ def run_one_case(
                 "latency": round(latency, 4),
                 "output_throughput": round(output_throughput, 2),
                 "overall_throughput": round(overall_throughput, 2),
+                "last_gen_throughput": round(last_gen_throughput, 2),
            }
            fout.write(json.dumps(res) + "\n")
 
+    return (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    )
+
 
 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
     if bench_args.base_url:
@@ -152,27 +230,38 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
 
     # warmup
     if not bench_args.skip_warmup:
+        print("=" * 8 + " Warmup Begin " + "=" * 8)
         run_one_case(
             base_url,
             batch_size=16,
             input_len=1024,
             output_len=16,
+            temperature=bench_args.temperature,
+            return_logprob=bench_args.return_logprob,
+            input_len_step_percentage=bench_args.input_len_step_percentage,
             run_name="",
             result_filename="",
         )
+        print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
 
     # benchmark
+    result = []
     try:
         for bs, il, ol in itertools.product(
             bench_args.batch_size, bench_args.input_len, bench_args.output_len
         ):
-            run_one_case(
-                base_url,
-                bs,
-                il,
-                ol,
-                bench_args.run_name,
-                bench_args.result_filename,
+            result.append(
+                run_one_case(
+                    base_url,
+                    bs,
+                    il,
+                    ol,
+                    temperature=bench_args.temperature,
+                    return_logprob=bench_args.return_logprob,
+                    input_len_step_percentage=bench_args.input_len_step_percentage,
+                    run_name=bench_args.run_name,
+                    result_filename=bench_args.result_filename,
+                )
             )
     finally:
         if proc:
@@ -180,6 +269,45 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
 
     print(f"\nResults are saved to {bench_args.result_filename}")
 
+    if not bench_args.show_report:
+        return
+
+    summary = " | batch size | latency (s) | input throughput (tok/s)  | output throughput (tok/s) | acc length | ITL (ms) | input price ($/1M) | output price ($/1M) |\n"
+    summary += "| ---------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ------------------ | ------------------- |\n"
+
+    for (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+    ) in result:
+        hourly_cost = 2 * server_args.tp_size  # $2/hour for one H100
+        input_util = 0.7
+        accept_length = round(acc_length, 2) if acc_length is not None else "n/a"
+        line = (
+            f"| {batch_size} | "
+            f"{latency:.2f} | "
+            f"{input_throughput:.2f} | "
+            f"{output_throughput:.2f} | "
+            f"{accept_length} | "
+            f"{1 / (output_throughput/batch_size) * 1000:.2f} | "
+            f"{1e6 / (input_throughput * input_util) / 3600 * hourly_cost:.2f} | "
+            f"{1e6 / output_throughput / 3600 * hourly_cost:.2f} |\n"
        )
+        summary += line
+
+    # print metrics table
+    print(summary)
+
+    if is_in_ci():
+        write_github_step_summary(
+            f"### Test Nightly Benchmark (bench_one_batch) \n{summary}"
        )
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
...
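The derived columns in the new `--show-report` table come straight from the tuple that `run_one_case` now returns; below is a standalone restatement of the arithmetic with made-up throughput numbers (only `tp_size`, the $2/hour H100 price, and the 0.7 input utilization mirror the patch, everything else is illustrative):

# Illustrative numbers; the formulas mirror the report code in run_benchmark above.
batch_size = 16
input_throughput = 20000.0   # tok/s, hypothetical
output_throughput = 1600.0   # tok/s, hypothetical
tp_size = 1

hourly_cost = 2 * tp_size    # $2/hour for one H100, as in the patch
input_util = 0.7             # assumed average prefill utilization

itl_ms = 1 / (output_throughput / batch_size) * 1000                      # per-request inter-token latency: 10.00 ms
input_price = 1e6 / (input_throughput * input_util) / 3600 * hourly_cost  # $ per 1M input tokens: ~0.04
output_price = 1e6 / output_throughput / 3600 * hourly_cost               # $ per 1M output tokens: ~0.35

print(f"ITL: {itl_ms:.2f} ms, input: ${input_price:.2f}/1M, output: ${output_price:.2f}/1M")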
@@ -1103,7 +1103,7 @@ async def benchmark(
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):
@@ -1239,12 +1239,14 @@ async def benchmark(
     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if pd_seperated:
-            accept_length = server_info.json()["decode"][0].get(
+        if pd_separated:
+            accept_length = server_info.json()["decode"][0]["internal_states"][0].get(
                 "avg_spec_accept_length", None
             )
         else:
-            accept_length = server_info.json().get("avg_spec_accept_length", None)
+            accept_length = server_info.json()["internal_states"][0].get(
+                "avg_spec_accept_length", None
+            )
     else:
         accept_length = None
@@ -1541,7 +1543,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-            pd_seperated=args.pd_seperated,
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )
...
@@ -37,6 +37,12 @@ class BaseGrammarObject:
         """
         raise NotImplementedError()
 
+    def rollback(self, k: int):
+        raise NotImplementedError()
+
+    def is_terminated(self):
+        raise NotImplementedError()
+
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
...
@@ -277,19 +277,17 @@ class SchedulerDisaggregationPrefillMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )
 
         # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
+            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
         else:
             next_token_ids = result.next_token_ids.tolist()
...
@@ -330,7 +330,7 @@ class Engine(EngineBase):
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
-            **internal_states,
+            "internal_states": internal_states,
             "version": __version__,
         }
...
@@ -222,7 +222,7 @@ async def get_server_info():
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
-        **internal_states,
+        "internal_states": internal_states,
         "version": __version__,
     }
...
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(
     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
...
@@ -160,6 +160,7 @@ class GenerationBatchResult:
     extend_input_len_per_req: List[int]
     extend_logprob_start_len_per_req: List[int]
     bid: int
+    can_run_cuda_graph: bool
 
 
 @dataclass
@@ -323,13 +324,14 @@ class Scheduler(
         set_random_seed(self.random_seed)
 
         # Print debug info
-        logger.info(
-            f"max_total_num_tokens={self.max_total_num_tokens}, "
-            f"chunked_prefill_size={server_args.chunked_prefill_size}, "
-            f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"max_running_requests={self.max_running_requests}, "
-            f"context_len={self.model_config.context_len}"
-        )
+        if tp_rank == 0:
+            logger.info(
+                f"max_total_num_tokens={self.max_total_num_tokens}, "
+                f"chunked_prefill_size={server_args.chunked_prefill_size}, "
+                f"max_prefill_tokens={self.max_prefill_tokens}, "
+                f"max_running_requests={self.max_running_requests}, "
+                f"context_len={self.model_config.context_len}"
+            )
 
         # Init memory pool and cache
         self.init_memory_pool_and_cache()
@@ -752,6 +754,7 @@ class Scheduler(
                 extend_input_len_per_req=None,
                 extend_logprob_start_len_per_req=None,
                 bid=bids[next_mb_id],
+                can_run_cuda_graph=result.can_run_cuda_graph,
             )
             self.process_batch_result(mbs[next_mb_id], output_result)
             last_mbs[next_mb_id] = mbs[next_mb_id]
@@ -1159,7 +1162,9 @@ class Scheduler(
             self.metrics_collector.log_stats(self.stats)
 
-    def log_decode_stats(self, running_batch=None):
+    def log_decode_stats(
+        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+    ):
         batch = running_batch or self.running_batch
         gap_latency = time.time() - self.last_decode_stats_tic
@@ -1199,6 +1204,7 @@ class Scheduler(
             msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
 
         msg += (
+            f"cuda graph: {can_run_cuda_graph}, "
            f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
            f"#queue-req: {len(self.waiting_queue)}"
        )
@@ -1524,11 +1530,11 @@ class Scheduler(
         if self.spec_algorithm.is_none():
             model_worker_batch = batch.get_model_worker_batch()
             if self.pp_group.is_last_rank:
-                logits_output, next_token_ids = (
+                logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             else:
-                pp_hidden_states_proxy_tensors, _ = (
+                pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
                 )
             bid = model_worker_batch.bid
@@ -1538,6 +1544,7 @@ class Scheduler(
                 next_token_ids,
                 bid,
                 num_accepted_tokens,
+                can_run_cuda_graph,
             ) = self.draft_worker.forward_batch_speculative_generation(batch)
             self.spec_num_total_accepted_tokens += (
                 num_accepted_tokens + batch.batch_size()
@@ -1571,6 +1578,7 @@ class Scheduler(
                 extend_input_len_per_req=extend_input_len_per_req,
                 extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
                 bid=bid,
+                can_run_cuda_graph=can_run_cuda_graph,
             )
         else:  # embedding or reward model
             model_worker_batch = batch.get_model_worker_batch()
...
@@ -38,20 +38,16 @@ class SchedulerOutputProcessorMixin:
             next_token_ids,
             extend_input_len_per_req,
             extend_logprob_start_len_per_req,
-            bid,
         ) = (
             result.logits_output,
             result.next_token_ids,
             result.extend_input_len_per_req,
             result.extend_logprob_start_len_per_req,
-            result.bid,
         )
 
         if self.enable_overlap:
-            logits_output, next_token_ids = (
-                self.tp_worker.resolve_last_batch_result(
-                    launch_done,
-                )
+            logits_output, next_token_ids, _ = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
         else:
             # Move next_token_ids and logprobs to cpu
@@ -189,16 +185,16 @@ class SchedulerOutputProcessorMixin:
         result: GenerationBatchResult,
         launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, bid = (
+        logits_output, next_token_ids, can_run_cuda_graph = (
             result.logits_output,
             result.next_token_ids,
-            result.bid,
+            result.can_run_cuda_graph,
         )
         self.num_generated_tokens += len(batch.reqs)
 
         if self.enable_overlap:
-            logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
-                launch_done
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.tp_worker.resolve_last_batch_result(launch_done)
             )
             next_token_logprobs = logits_output.next_token_logprobs
         elif batch.spec_algorithm.is_none():
@@ -280,7 +276,7 @@ class SchedulerOutputProcessorMixin:
             self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
-            self.log_decode_stats(running_batch=batch)
+            self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
 
     def add_input_logprob_return_values(
         self: Scheduler,
...
@@ -923,12 +923,13 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)
 
-    async def get_internal_state(self) -> Dict[Any, Any]:
+    async def get_internal_state(self) -> List[Dict[Any, Any]]:
         req = GetInternalStateReq()
-        res: List[GetInternalStateReqOutput] = (
+        responses: List[GetInternalStateReqOutput] = (
             await self.get_internal_state_communicator(req)
         )
-        return res[0].internal_state
+        # Many DP ranks
+        return [res.internal_state for res in responses]
 
     def get_log_request_metadata(self):
         max_length = None
...
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union
 import torch
 
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -183,8 +183,11 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
+        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Tuple[
+        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
+    ]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
 
         pp_proxy_tensors = None
@@ -196,11 +199,11 @@ class TpModelWorker:
             )
 
         if self.pp_group.is_last_rank:
-            logits_output = self.model_runner.forward(
+            logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if model_worker_batch.launch_done is not None:
-                model_worker_batch.launch_done.set()
+            if launch_done is not None:
+                launch_done.set()
 
             if skip_sample:
                 next_token_ids = None
@@ -209,17 +212,17 @@ class TpModelWorker:
                     logits_output, model_worker_batch
                 )
 
-            return logits_output, next_token_ids
+            return logits_output, next_token_ids, can_run_cuda_graph
         else:
-            pp_proxy_tensors = self.model_runner.forward(
+            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None
+            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
+        logits_output, _ = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings
         return embeddings
...
@@ -18,7 +18,7 @@ import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional
+from typing import Optional, Tuple
 
 import psutil
 import torch
@@ -145,8 +145,10 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
             # Run forward
-            logits_output, next_token_ids = self.worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.worker.forward_batch_generation(
+                    model_worker_batch, model_worker_batch.launch_done
+                )
             )
 
             # Update the future token ids map
@@ -171,14 +173,18 @@ class TpModelWorkerClient:
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()
 
-            self.output_queue.put((copy_done, logits_output, next_token_ids))
+            self.output_queue.put(
+                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
+            )
 
     def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
         """
        This function is called to resolve the last batch result and
        wait for the current batch to be launched. Used in overlap mode.
        """
-        copy_done, logits_output, next_token_ids = self.output_queue.get()
+        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
+            self.output_queue.get()
+        )
        if launch_done is not None:
            launch_done.wait()
@@ -193,9 +199,11 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        return logits_output, next_token_ids, can_run_cuda_graph
 
-    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+    def forward_batch_generation(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> Tuple[None, torch.Tensor, bool]:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -223,7 +231,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, False
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
...
@@ -19,7 +19,7 @@ import bisect
 import inspect
 import os
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Callable
+from typing import TYPE_CHECKING, Callable, Optional, Union
 
 import torch
 import tqdm
@@ -40,15 +40,12 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
-    is_hip,
     rank0_log,
 )
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
-_is_hip = is_hip()
-
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -137,7 +134,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         )
     gpu_mem = get_device_memory_capacity()
-    # Batch size of each rank will not become so large when DP is on
     if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))
@@ -148,12 +144,15 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             model_runner.req_to_token_pool.size
         ]
 
-    capture_bs = list(sorted(set(capture_bs)))
-    assert len(capture_bs) > 0 and capture_bs[0] > 0
-    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
+        if max(capture_bs) < server_args.cuda_graph_max_bs:
+            capture_bs += list(
+                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
+            )
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    capture_bs = list(sorted(set(capture_bs)))
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
@@ -1085,32 +1085,33 @@ class ModelRunner:
         forward_batch: ForwardBatch,
         skip_attn_backend_init: bool = False,
         pp_proxy_tensors: Optional[PPProxyTensors] = None,
-    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+    ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
         can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         )
         if can_run_cuda_graph:
-            return self.cuda_graph_runner.replay(
+            ret = self.cuda_graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-
-        if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+        elif forward_batch.forward_mode.is_decode():
+            ret = self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            ret = self.forward_extend(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
+            ret = self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
+        return ret, can_run_cuda_graph
+
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
...
@@ -1086,7 +1086,7 @@ class ServerArgs:
             "--cuda-graph-max-bs",
             type=int,
             default=ServerArgs.cuda_graph_max_bs,
-            help="Set the maximum batch size for cuda graph.",
+            help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
         )
         parser.add_argument(
             "--cuda-graph-bs",
...
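The flag's new semantics correspond to the `get_batch_sizes_to_capture` change earlier in this diff: when `--cuda-graph-max-bs` exceeds the largest default capture size, the capture list is extended toward it in steps of 16. A rough standalone restatement (the function name and the starting list below are illustrative, not part of the patch):

# Hypothetical restatement of the capture-list extension logic; names are local to this sketch.
def extend_capture_bs(capture_bs, cuda_graph_max_bs, req_to_token_pool_size):
    if cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
        if max(capture_bs) < cuda_graph_max_bs:
            # Extend toward the requested maximum in steps of 16.
            capture_bs += list(range(max(capture_bs), cuda_graph_max_bs + 1, 16))
    capture_bs = [bs for bs in capture_bs if bs <= req_to_token_pool_size]
    return sorted(set(capture_bs))

# e.g. a default-style list that tops out at 160 gets extended to 512:
print(extend_capture_bs(list(range(1, 33)) + [48, 64, 96, 128, 160], 512, 4096)[-5:])
# -> [448, 464, 480, 496, 512]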
@@ -251,8 +251,8 @@ class EAGLEWorker(TpModelWorker):
         if batch.forward_mode.is_decode():
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
-            logits_output, verify_output, model_worker_batch = self.verify(
-                batch, spec_info
+            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
+                self.verify(batch, spec_info)
             )
 
             # If it is None, it means all requests are finished
@@ -264,21 +264,22 @@ class EAGLEWorker(TpModelWorker):
                 verify_output.verified_id,
                 model_worker_batch.bid,
                 sum(verify_output.accept_length_per_req_cpu),
+                can_run_cuda_graph,
             )
         elif batch.forward_mode.is_idle():
             model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
+            logits_output, next_token_ids, _ = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
             )
 
-            return logits_output, next_token_ids, model_worker_batch.bid, 0
+            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
         else:
             logits_output, next_token_ids, bid = self.forward_target_extend(batch)
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend(
                     batch, logits_output.hidden_states, next_token_ids
                 )
-            return logits_output, next_token_ids, bid, 0
+            return logits_output, next_token_ids, bid, 0, False
 
     def forward_target_extend(
         self, batch: ScheduleBatch
@@ -297,7 +298,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+        logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
         return logits_output, next_token_ids, model_worker_batch.bid
@@ -478,8 +479,10 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.TARGET_VERIFY
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch()
-        logits_output, _ = self.target_worker.forward_batch_generation(
-            model_worker_batch, skip_sample=True
+        logits_output, _, can_run_cuda_graph = (
+            self.target_worker.forward_batch_generation(
+                model_worker_batch, skip_sample=True
+            )
         )
         self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
@@ -504,7 +507,7 @@ class EAGLEWorker(TpModelWorker):
         if batch.return_logprob:
             self.add_logprob_values(batch, res, logits_output)
 
-        return logits_output, res, model_worker_batch
+        return logits_output, res, model_worker_batch, can_run_cuda_graph
 
     def add_logprob_values(
         self,
@@ -590,7 +593,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch, self.draft_model_runner
         )
         forward_batch.return_logprob = False
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
         assert forward_batch.spec_info is batch.spec_info
@@ -617,7 +620,7 @@ class EAGLEWorker(TpModelWorker):
         )
 
         # Run
-        logits_output = self.draft_model_runner.forward(forward_batch)
+        logits_output, _ = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         self.capture_for_decode(logits_output, forward_batch.spec_info)
...
@@ -395,12 +395,12 @@ def popen_launch_server(
     other_args: list[str] = (),
     env: Optional[dict] = None,
     return_stdout_stderr: Optional[tuple] = None,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
 
-    if pd_seperated:
+    if pd_separated:
         command = "sglang.launch_pd_server"
     else:
         command = "sglang.launch_server"
@@ -414,7 +414,7 @@ def popen_launch_server(
         *[str(x) for x in other_args],
     ]
 
-    if pd_seperated:
+    if pd_separated:
         command.extend(
             [
                 "--lb-host",
@@ -656,7 +656,7 @@ def get_benchmark_args(
     disable_stream=False,
     disable_ignore_eos=False,
     seed: int = 0,
-    pd_seperated: bool = False,
+    pd_separated: bool = False,
 ):
     return SimpleNamespace(
         backend="sglang",
@@ -686,7 +686,7 @@ def get_benchmark_args(
         profile=None,
         lora_name=None,
         prompt_suffix="",
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )
 
 
@@ -750,7 +750,7 @@ def run_bench_serving_multi(
     other_server_args,
     benchmark_args,
     need_warmup=False,
-    pd_seperated=False,
+    pd_separated=False,
 ):
     # Launch the server
     process = popen_launch_server(
@@ -758,7 +758,7 @@ def run_bench_serving_multi(
         base_url,
         timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
         other_args=other_server_args,
-        pd_seperated=pd_seperated,
+        pd_separated=pd_separated,
     )
 
     # run benchmark for all
...
@@ -101,8 +101,8 @@ suites = {
         # TestFile("test_deepep_intranode.py", 50),
         # TestFile("test_deepep_low_latency.py", 50),
         # TestFile("test_moe_deepep_eval_accuracy_large.py", 250),
+        # TestFile("test_disaggregation.py", 90),
         TestFile("test_local_attn.py", 250),
-        TestFile("test_disaggregation.py", 90),
         TestFile("test_full_deepseek_v3.py", 250),
         TestFile("test_pp_single_node.py", 150),
     ],
...