Commit 287d07a6 authored by Lianmin Zheng

Misc fixes for eagle (flush_cache, CPU overhead) (#3014)

parent d2571dd5
@@ -49,12 +49,13 @@ class BenchArgs:
     gsp_system_prompt_len: int = 2048
     gsp_question_len: int = 128
     gsp_output_len: int = 256
+    seed: int = 1
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
-    seed: int = 1
+    apply_chat_template: bool = False
+    profile: bool = False
     skip_warmup: bool = False
     do_not_exit: bool = False
-    profile: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -141,20 +142,31 @@ class BenchArgs:
             default=BenchArgs.gsp_output_len,
             help="Target length in tokens for outputs in generated-shared-prefix dataset",
         )
+        parser.add_argument("--seed", type=int, default=1, help="The random seed.")
         parser.add_argument(
             "--disable-ignore-eos",
-            type=bool,
-            default=BenchArgs.disable_ignore_eos,
+            action="store_true",
             help="Disable ignore EOS token",
         )
         parser.add_argument(
             "--extra-request-body",
             metavar='{"key1": "value1", "key2": "value2"}',
             type=str,
+            default=BenchArgs.extra_request_body,
             help="Append given JSON object to the request payload. You can use this to specify"
             "additional generate params like sampling params.",
         )
-        parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--apply-chat-template",
+            action="store_true",
+            help="Apply chat template",
+        )
+        parser.add_argument(
+            "--profile",
+            action="store_true",
+            help="Use Torch Profiler. The endpoint must be launched with "
+            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
+        )
         parser.add_argument(
             "--skip-warmup",
             action="store_true",
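A note on the --disable-ignore-eos change above: with argparse, type=bool calls bool() on the raw string, so "--disable-ignore-eos False" still evaluates to True; action="store_true" turns it into a real on/off flag. A minimal standalone sketch of the difference (illustrative only, not part of the benchmark script):

import argparse

# type=bool: any non-empty string is truthy, so "--flag False" still gives True.
p1 = argparse.ArgumentParser()
p1.add_argument("--flag", type=bool, default=False)
print(p1.parse_args(["--flag", "False"]).flag)   # True (surprising)

# action="store_true": False unless the flag is present on the command line.
p2 = argparse.ArgumentParser()
p2.add_argument("--flag", action="store_true")
print(p2.parse_args([]).flag)                    # False
print(p2.parse_args(["--flag"]).flag)            # True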
@@ -165,12 +177,6 @@ class BenchArgs:
             action="store_true",
             help="Do not exit the program. This is useful for nsys profile with --duration and --delay.",
         )
-        parser.add_argument(
-            "--profile",
-            action="store_true",
-            help="Use Torch Profiler. The endpoint must be launched with "
-            "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
@@ -453,6 +453,7 @@ def get_dataset(args, tokenizer):
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
             context_len=args.sharegpt_context_len,
+            apply_chat_template=args.apply_chat_template,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
@@ -517,6 +518,7 @@ class BenchmarkMetrics:
     median_e2e_latency_ms: float
     std_e2e_latency_ms: float
     p99_e2e_latency_ms: float
+    concurrency: float


 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
@@ -562,6 +564,7 @@ def sample_sharegpt_requests(
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
     context_len: Optional[int] = None,
+    apply_chat_template=False,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -592,6 +595,15 @@ def sample_sharegpt_requests(
         # Tokenize the prompts and completions.
         prompt = dataset[i][0]
+        if apply_chat_template:
+            prompt = tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}],
+                add_generation_prompt=True,
+                tokenize=False,
+            )
+            prompt = prompt.replace(tokenizer.bos_token, "")
+
         prompt_token_ids = tokenizer.encode(prompt)
         completion = dataset[i][1]
         completion_token_ids = tokenizer.encode(completion)
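For context, the new apply_chat_template path wraps each ShareGPT prompt in the tokenizer's chat template before lengths are measured. A standalone sketch of the same Hugging Face call (the model name is illustrative, any chat-capable tokenizer works; stripping the BOS token presumably avoids it being added a second time when the prompt is tokenized again on the server side):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # illustrative choice

prompt = "Explain speculative decoding in one sentence."
templated = tokenizer.apply_chat_template(
    [{"role": "user", "content": prompt}],
    add_generation_prompt=True,  # append the assistant header so the model starts answering
    tokenize=False,              # keep it as a string; the benchmark re-encodes it later
)
if tokenizer.bos_token:
    templated = templated.replace(tokenizer.bos_token, "")
print(templated)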
@@ -600,7 +612,7 @@ def sample_sharegpt_requests(
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 1 or output_len < 1:
+        if prompt_len < 2 or output_len < 2:
             # Prune too short sequences.
             continue
@@ -880,6 +892,7 @@ def calculate_metrics(
         median_e2e_latency_ms=np.median(e2e_latencies) * 1000,
         std_e2e_latency_ms=np.std(e2e_latencies) * 1000,
         p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000,
+        concurrency=np.sum(e2e_latencies) / dur_s,
     )

     return metrics, output_lens
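The new concurrency metric is the total end-to-end latency divided by the benchmark wall-clock time, which by Little's law equals the average number of requests in flight during the run. A toy worked example (the numbers are made up, not benchmark output):

import numpy as np

# Four requests, each taking 2 s end to end, served within a 4 s window.
e2e_latencies = np.array([2.0, 2.0, 2.0, 2.0])
dur_s = 4.0

# Little's law: in-flight = arrival rate * mean latency
#             = (completed / dur_s) * mean(latencies) = sum(latencies) / dur_s
concurrency = np.sum(e2e_latencies) / dur_s
print(concurrency)  # 2.0 -> on average two requests were being processed at once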
@@ -1031,6 +1044,7 @@ async def benchmark(
             "Total token throughput (tok/s):", metrics.total_throughput
         )
     )
+    print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency))
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1062,13 +1076,24 @@ async def benchmark(
         and metrics.output_throughput is not None
     ):
         result = {
+            # Arguments
             "backend": args.backend,
             "dataset_name": args.dataset_name,
             "request_rate": request_rate,
             "max_concurrency": max_concurrency,
+            "sharegpt_output_len": args.sharegpt_output_len,
+            "random_input_len": args.random_input_len,
+            "random_output_len": args.random_output_len,
+            "random_range_ratio": args.random_range_ratio,
+            # Results
+            "duration": benchmark_duration,
+            "completed": metrics.completed,
             "total_input_tokens": metrics.total_input,
             "total_output_tokens": metrics.total_output,
             "total_output_tokens_retokenized": metrics.total_output_retokenized,
+            "request_throughput": metrics.request_throughput,
+            "input_throughput": metrics.input_throughput,
+            "output_throughput": metrics.output_throughput,
             "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
             "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
             "std_e2e_latency_ms": metrics.std_e2e_latency_ms,
@@ -1085,14 +1110,7 @@ async def benchmark(
             "median_itl_ms": metrics.median_itl_ms,
             "std_itl_ms": metrics.std_itl_ms,
             "p99_itl_ms": metrics.p99_itl_ms,
-            "input_throughput": metrics.input_throughput,
-            "output_throughput": metrics.output_throughput,
-            "sharegpt_output_len": args.sharegpt_output_len,
-            "random_input_len": args.random_input_len,
-            "random_output_len": args.random_output_len,
-            "random_range_ratio": args.random_range_ratio,
-            "duration": benchmark_duration,
-            "completed": metrics.completed,
+            "concurrency": metrics.concurrency,
         }
     else:
         print(f"Error running benchmark for request rate: {request_rate}")
@@ -1112,36 +1130,16 @@ async def benchmark(
         with open(output_file_name, "a") as file:
             file.write(json.dumps(result) + "\n")

-    result = {
-        "duration": benchmark_duration,
-        "completed": metrics.completed,
-        "total_input_tokens": metrics.total_input,
-        "total_output_tokens": metrics.total_output,
-        "total_output_tokens_retokenized": metrics.total_output_retokenized,
-        "request_throughput": metrics.request_throughput,
-        "input_throughput": metrics.input_throughput,
-        "output_throughput": metrics.output_throughput,
-        "mean_ttft_ms": metrics.mean_ttft_ms,
-        "median_ttft_ms": metrics.median_ttft_ms,
-        "std_ttft_ms": metrics.std_ttft_ms,
-        "p99_ttft_ms": metrics.p99_ttft_ms,
-        "mean_tpot_ms": metrics.mean_tpot_ms,
-        "median_tpot_ms": metrics.median_tpot_ms,
-        "std_tpot_ms": metrics.std_tpot_ms,
-        "p99_tpot_ms": metrics.p99_tpot_ms,
-        "mean_itl_ms": metrics.mean_itl_ms,
-        "median_itl_ms": metrics.median_itl_ms,
-        "std_itl_ms": metrics.std_itl_ms,
-        "p99_itl_ms": metrics.p99_itl_ms,
-        "input_lens": [output.prompt_len for output in outputs],
-        "output_lens": output_lens,
-        "ttfts": [output.ttft for output in outputs],
-        "itls": [output.itl for output in outputs],
-        "generated_texts": [output.generated_text for output in outputs],
-        "errors": [output.error for output in outputs],
-        "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
-        "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
-    }
+    result.update(
+        {
+            "input_lens": [output.prompt_len for output in outputs],
+            "output_lens": output_lens,
+            "ttfts": [output.ttft for output in outputs],
+            "itls": [output.itl for output in outputs],
+            "generated_texts": [output.generated_text for output in outputs],
+            "errors": [output.error for output in outputs],
+        }
+    )
     return result
@@ -1422,7 +1420,6 @@ if __name__ == "__main__":
         "actual request rate may be lower than specified with --request-rate, "
         "if the server is not processing requests fast enough to keep up.",
     )
-    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
         "--multi",
         action="store_true",
@@ -1446,14 +1443,15 @@ if __name__ == "__main__":
         help="Disable streaming mode.",
     )
     parser.add_argument(
-        "--disable-ignore-eos",
+        "--return-logprob",
         action="store_true",
-        help="Disable ignoring EOS.",
+        help="Return logprob.",
     )
+    parser.add_argument("--seed", type=int, default=1, help="The random seed.")
     parser.add_argument(
-        "--return-logprob",
+        "--disable-ignore-eos",
         action="store_true",
-        help="Return logprob.",
+        help="Disable ignoring EOS.",
     )
     parser.add_argument(
         "--extra-request-body",
@@ -1462,6 +1460,11 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
+    parser.add_argument(
+        "--apply-chat-template",
+        action="store_true",
+        help="Apply chat template",
+    )
     parser.add_argument(
         "--profile",
         action="store_true",
...
@@ -1023,7 +1023,7 @@ class Scheduler:
         )

         # Check for jump-forward
-        if not self.disable_jump_forward:
+        if not self.disable_jump_forward and batch.has_grammar:
             jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
             self.waiting_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -1564,6 +1564,15 @@ class Scheduler:
             self.grammar_backend.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
+
+            if not self.spec_algorithm.is_none():
+                self.draft_worker.model_runner.req_to_token_pool.clear()
+                self.draft_worker.model_runner.token_to_kv_pool.clear()
+
+            self.num_generated_tokens = 0
+            self.forward_ct_decode = 0
+            self.spec_num_total_accepted_tokens = 0
+            self.spec_num_total_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
...
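The flush_cache change above matters when speculative decoding is enabled: the EAGLE draft model keeps its own request-to-token and KV pools inside its model runner, so clearing only the target model's pools would leave stale draft entries behind, and the speculative accept-rate counters would keep accumulating across the flush. A condensed sketch of the same idea outside the Scheduler class (attribute names follow the diff; the surrounding object is simplified):

def flush_all_pools(scheduler) -> None:
    # Target model pools (already cleared before this commit).
    scheduler.req_to_token_pool.clear()
    scheduler.token_to_kv_pool.clear()

    # Draft model pools only exist when a speculative algorithm is active.
    if not scheduler.spec_algorithm.is_none():
        scheduler.draft_worker.model_runner.req_to_token_pool.clear()
        scheduler.draft_worker.model_runner.token_to_kv_pool.clear()

    # Reset running counters so post-flush statistics start from zero.
    scheduler.num_generated_tokens = 0
    scheduler.forward_ct_decode = 0
    scheduler.spec_num_total_accepted_tokens = 0
    scheduler.spec_num_total_forward_ct = 0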
@@ -282,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ class ForwardBatch:
         if model_runner.model_is_mrope:
             ret.compute_mrope_positions(model_runner, batch)

-        # Init attention information
-        ret.req_to_token_pool = model_runner.req_to_token_pool
-        ret.token_to_kv_pool = model_runner.token_to_kv_pool
-        ret.attn_backend = model_runner.attn_backend
-
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
...
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================

-# Some shortcuts for backward compatbility.
+# Some shortcuts for backward compatibility.
 # They will be removed in new versions.
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
 class EAGLEDraftInput(SpecInfo):
     def __init__(self):
         self.prev_mode = ForwardMode.DECODE
-        self.sample_output = None

         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0

+        # shape: (b, hidden_size)
         self.hidden_states: torch.Tensor = None
+        # shape: (b,)
         self.verified_id: torch.Tensor = None
+        # shape: (b, vocab_size)
+        self.sample_output: torch.Tensor = None
+
         self.positions: torch.Tensor = None
         self.accept_length: torch.Tensor = None
-        self.has_finished: bool = False
-        self.unfinished_index: List[int] = None
+        self.accept_length_cpu: List[int] = None

     def load_server_args(self, server_args: ServerArgs):
         self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
                 :pre_len
             ] = req.prefix_indices
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
            )
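On the indexing change just above: both forms write into the same underlying req_to_token tensor, since the first index of the chained form returns a view, but the combined row-and-slice index does it in a single indexing call without materializing that intermediate view on every request. A tiny equivalence check on a toy tensor (illustrative, not the real pool layout):

import torch

# Toy stand-in for req_to_token: one row per request slot.
req_to_token = torch.zeros(4, 8, dtype=torch.long)
cache_loc = torch.arange(3, dtype=torch.long)

req_to_token[2][1:4] = cache_loc   # chained: index, then slice the returned view
req_to_token[3, 1:4] = cache_loc   # combined: one indexing call, same effect

print(torch.equal(req_to_token[2], req_to_token[3]))  # True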
@@ -295,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
-            + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+            + torch.full(
+                [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
+            )
         ).flatten()

         bs = len(batch.seq_lens)
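Replacing torch.ones(...) * self.iter with torch.full(...) produces the same position offsets while skipping one tensor allocation and one elementwise multiply per draft step, in line with the CPU-overhead theme of this commit. A quick equivalence check (on CPU, with made-up sizes):

import torch

topk, it = 4, 3  # illustrative top-k width and draft iteration index

a = torch.ones([1, topk], dtype=torch.long) * it            # old pattern
b = torch.full([1, topk], fill_value=it, dtype=torch.long)  # new pattern

print(torch.equal(a, b))  # True: same values, one fewer op per step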
@@ -312,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):
     def prepare_extend_after_decode(self, batch: ScheduleBatch):
         batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
-        batch.extend_lens = (self.accept_length + 1).tolist()
+        accept_length_cpu = batch.spec_info.accept_length_cpu
+        batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        seq_lens_cpu = batch.seq_lens.tolist()

         pt = 0
-        seq_lens = batch.seq_lens.tolist()
         i = 0
         for req in batch.reqs:
             if req.finished():
                 continue
             # assert seq_len - pre_len == req.extend_input_len
-            input_len = self.accept_length[i] + 1
-            seq_len = seq_lens[i]
+            input_len = batch.extend_lens[i]
+            seq_len = seq_lens_cpu[i]
             batch.req_to_token_pool.req_to_token[req.req_pool_idx][
                 seq_len - input_len : seq_len
             ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
+        assert pt == batch.out_cache_loc.shape[0]

         self.positions = torch.empty_like(self.verified_id)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -345,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )
-        batch.seq_lens_sum = sum(batch.seq_lens)
+        batch.seq_lens_sum = sum(seq_lens_cpu)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id
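prepare_extend_after_decode now reads accept lengths from the precomputed accept_length_cpu list and from seq_lens_for_draft_extend instead of indexing CUDA tensors inside the per-request loop; every .tolist()/int() on a CUDA tensor forces a device-to-host copy and a synchronization, so doing it once and then working on plain Python lists trims per-step CPU overhead. A hedged sketch of the pattern (toy data, runs on CPU or GPU):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
accept_length = torch.tensor([1, 3, 2], device=device)

# Costly pattern: each int()/item() on a device tensor is a separate sync.
extend_lens_slow = [int(accept_length[i]) + 1 for i in range(len(accept_length))]

# Cheaper pattern: one transfer, then pure-Python bookkeeping.
accept_length_cpu = accept_length.tolist()
extend_lens_fast = [x + 1 for x in accept_length_cpu]
seq_lens_sum = sum(extend_lens_fast)

print(extend_lens_slow == extend_lens_fast, seq_lens_sum)  # True 9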
@@ -573,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
         finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
+        has_finished = False
+
         # iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -586,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
                 finished_extend_len[req.rid] = j + 1
                 req.check_finished()
                 if req.finished():
-                    draft_input.has_finished = True
+                    has_finished = True
                     # set all tokens after finished token to -1 and break
                     accept_index[i, j + 1 :] = -1
                     break
@@ -600,7 +608,6 @@ class EagleVerifyInput(SpecInfo):
         accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-        verified_id_cpu = verified_id.tolist()

         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
@@ -622,7 +629,13 @@ class EagleVerifyInput(SpecInfo):
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.unfinished_index = unfinished_index
+            draft_input.accept_length_cpu = [
+                accept_length_cpu[i] for i in unfinished_index
+            ]
+            if has_finished:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+            else:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens

         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
         return (
...
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+from sglang.srt.utils import rank0_print


 class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):
     def forward_draft_decode(self, batch: ScheduleBatch):
         batch.spec_info.prepare_for_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)

     def forward_draft_extend(self, batch: ScheduleBatch):
         self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
         self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
         batch.req_to_token_pool = runner.req_to_token_pool

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        seq_lens_backup = batch.seq_lens
+
         self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        if batch.spec_info.has_finished:
-            index = batch.spec_info.unfinished_index
-            seq_lens = batch.seq_lens
-            batch.seq_lens = batch.seq_lens[index]
-
         batch.spec_info.prepare_extend_after_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)

+        batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
-        batch.forward_mode = ForwardMode.DECODE
-        if batch.spec_info.has_finished:
-            batch.seq_lens = seq_lens
         self._set_mem_pool(batch, self.target_worker.model_runner)

+        # Restore backup.
+        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.forward_mode = ForwardMode.DECODE
+        batch.seq_lens = seq_lens_backup
+
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ):
...
@@ -1442,3 +1442,10 @@ def is_valid_ipv6_address(address: str) -> bool:
         return True
     except ValueError:
         return False
+
+
+def rank0_print(msg: str):
+    from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+    if get_tensor_model_parallel_rank() == 0:
+        print(msg, flush=True)
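rank0_print is a small debugging helper so that log lines under tensor parallelism are emitted once rather than once per TP rank; importing get_tensor_model_parallel_rank inside the function body presumably avoids a circular import at module load time. A hedged usage sketch (assumes the distributed group has already been initialized, as it is inside the running server):

from sglang.srt.utils import rank0_print

# Somewhere inside model/worker code running under tensor parallelism:
rank0_print("EAGLE draft step finished")  # emitted only on TP rank 0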
@@ -535,7 +535,8 @@ def test_hellaswag_select():

     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
-    assert np.abs(accuracy_gen - accuracy) < 0.1
+    print(f"{accuracy=}, {accuracy_gen=}")
+    assert np.abs(accuracy_gen - accuracy) < 0.05
     assert np.abs(latency_gen - latency) < 1

     return accuracy, latency
...
@@ -567,15 +567,16 @@ def run_bench_serving(
         random_range_ratio=0.0,
         request_rate=request_rate,
         multi=None,
-        seed=0,
         output_file=None,
         disable_tqdm=False,
         disable_stream=disable_stream,
-        disable_ignore_eos=False,
         return_logprob=False,
-        lora_name=None,
+        seed=0,
+        disable_ignore_eos=False,
         extra_request_body=None,
+        apply_chat_template=False,
         profile=None,
+        lora_name=None,
     )

     try:
...
""" """
Usage: Usage:
python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens
python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select
""" """
import unittest import unittest
......