Unverified commit 7b4e61ff authored by Ying Sheng, committed by GitHub

[Fix] Fix eagle with disable cuda graph (#3411)

parent 6222e1c2
@@ -924,38 +924,50 @@ class FlashInferMultiStepDraftBackend:
         self.max_context_len = self.attn_backends[0].max_context_len
 
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
-        self.kv_indptr_stride = self.kv_indptr.shape[1]
 
-    def common_template(self, forward_batch: ForwardBatch, call_fn: int):
+    def common_template(
+        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+    ):
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
         self.generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
             forward_batch.req_to_token_pool.req_to_token,
             forward_batch.seq_lens,
-            self.cuda_graph_kv_indices,
+            kv_indices_buffer,
             self.kv_indptr,
             forward_batch.positions,
             num_seqs,
             self.topk,
             self.pool_len,
-            self.kv_indptr_stride,
+            kv_indices_buffer.shape[1],
             self.kv_indptr.shape[1],
             triton.next_power_of_2(num_seqs),
             triton.next_power_of_2(self.speculative_num_steps),
             triton.next_power_of_2(bs),
         )
 
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
-            forward_batch.spec_info.kv_indices = self.cuda_graph_kv_indices[i][
+            forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
             ]
             call_fn(i, forward_batch)
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        kv_indices = torch.zeros(
+            (
+                self.speculative_num_steps,
+                forward_batch.batch_size * self.topk * self.max_context_len,
+            ),
+            dtype=torch.int32,
+            device="cuda",
+        )
+
         def call_fn(i, forward_batch):
             forward_batch.spec_info.kv_indptr = (
                 forward_batch.spec_info.kv_indptr.clone()
@@ -965,7 +977,7 @@ class FlashInferMultiStepDraftBackend:
             )
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, kv_indices, call_fn)
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_kv_indices = torch.zeros(
@@ -973,7 +985,6 @@ class FlashInferMultiStepDraftBackend:
             dtype=torch.int32,
             device="cuda",
         )
-        self.kv_indptr_stride = self.cuda_graph_kv_indices.shape[1]
 
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
@@ -995,7 +1006,7 @@ class FlashInferMultiStepDraftBackend:
             ][0]
             decode_wrapper.begin_forward = partial(fast_decode_plan, decode_wrapper)
 
-        self.common_template(forward_batch, call_fn)
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
     def init_forward_metadata_replay_cuda_graph(self, forward_batch):
         def call_fn(i, forward_batch):
@@ -1009,7 +1020,7 @@ class FlashInferMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )
 
-        self.common_template(forward_batch, call_fn)
 
 @triton.jit
+        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
 
 
 @triton.jit
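Reading the diff, the fix appears to be about where the draft KV-indices buffer comes from: `common_template` previously always wrote into `self.cuda_graph_kv_indices`, which is only allocated inside `init_cuda_graph_state`, so the eager `init_forward_metadata` path had no valid buffer when CUDA graphs were disabled. After the patch the buffer is an explicit `kv_indices_buffer` argument: the eager path allocates a scratch tensor per call, while the CUDA-graph paths keep passing the persistent `self.cuda_graph_kv_indices`. The snippet below is a runnable toy sketch of that handoff, not the real backend; the class, shapes, and CPU tensors are simplified stand-ins for what the diff shows.

import torch

class DraftBackendSketch:
    # Toy model of the buffer handoff in this patch; names follow the diff,
    # but the logic, shapes, and device placement are heavily simplified.
    def __init__(self, speculative_num_steps=3, topk=4, max_context_len=16):
        self.speculative_num_steps = speculative_num_steps
        self.topk = topk
        self.max_context_len = max_context_len
        # Note: cuda_graph_kv_indices does NOT exist yet; it is only created by
        # init_cuda_graph_state(), which never runs when CUDA graphs are disabled.

    def init_cuda_graph_state(self, max_bs):
        # Persistent buffer reused across CUDA graph capture/replay.
        self.cuda_graph_kv_indices = torch.zeros(
            (self.speculative_num_steps, max_bs * self.topk * self.max_context_len),
            dtype=torch.int32,
        )

    def common_template(self, batch_size, kv_indices_buffer):
        # The buffer is now an explicit argument, so this code path no longer
        # depends on an attribute that only exists in CUDA-graph mode.
        return kv_indices_buffer[0, : batch_size * self.topk]

    def init_forward_metadata(self, batch_size):
        # Eager path (e.g. --disable-cuda-graph): allocate a per-call scratch buffer.
        kv_indices = torch.zeros(
            (self.speculative_num_steps, batch_size * self.topk * self.max_context_len),
            dtype=torch.int32,
        )
        return self.common_template(batch_size, kv_indices)

backend = DraftBackendSketch()
print(backend.init_forward_metadata(batch_size=2).shape)  # works with no CUDA graph state at all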
@@ -22,12 +22,12 @@ from sglang.test.test_utils import (
 
 class TestEAGLEEngine(unittest.TestCase):
     def test_eagle_accuracy(self):
-        prompt = "Today is a sunny day and I like"
-        sampling_params = {"temperature": 0, "max_new_tokens": 8}
+        prompt1 = "Today is a sunny day and I like"
+        sampling_params1 = {"temperature": 0, "max_new_tokens": 8}
 
         # Get the reference output
         ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-        ref_output = ref_engine.generate(prompt, sampling_params)["text"]
+        ref_output = ref_engine.generate(prompt1, sampling_params1)["text"]
         ref_engine.shutdown()
 
         # Test cases with different configurations
@@ -60,20 +60,20 @@ class TestEAGLEEngine(unittest.TestCase):
             engine = sgl.Engine(**config)
 
             # Case 1: Test the output of EAGLE engine is the same as normal engine
-            out1 = engine.generate(prompt, sampling_params)["text"]
+            out1 = engine.generate(prompt1, sampling_params1)["text"]
             print(f"{out1=}, {ref_output=}")
             self.assertEqual(out1, ref_output)
 
             # Case 2: Test the output of EAGLE engine does not contain unexpected EOS
-            prompt = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
-            sampling_params = {
+            prompt2 = "[INST] <<SYS>>\\nYou are a helpful assistant.\\n<</SYS>>\\nToday is a sunny day and I like [/INST]"
+            sampling_params2 = {
                 "temperature": 0,
                 "max_new_tokens": 1024,
                 "skip_special_tokens": False,
             }
 
             tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST)
-            out2 = engine.generate(prompt, sampling_params)["text"]
+            out2 = engine.generate(prompt2, sampling_params2)["text"]
             print(f"{out2=}")
             tokens = tokenizer.encode(out2, truncation=False)
             assert tokenizer.eos_token_id not in tokens
@@ -85,8 +85,8 @@ class TestEAGLEEngine(unittest.TestCase):
             "The capital of France is",
             "The future of AI is",
         ]
-        sampling_params = {"temperature": 0, "max_new_tokens": 30}
-        outputs = engine.generate(prompts, sampling_params)
+        sampling_params3 = {"temperature": 0, "max_new_tokens": 30}
+        outputs = engine.generate(prompts, sampling_params3)
 
         for prompt, output in zip(prompts, outputs):
             print("===============================")
             print(f"Prompt: {prompt}\nGenerated text: {output['text']}")
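The test-side edits look like a scoping cleanup rather than a behavioral change: `prompt` and `sampling_params` were rebound across cases, which makes it easy to accidentally couple one case's inputs to another's. Numbering the names (`prompt1`/`sampling_params1`, `prompt2`/`sampling_params2`, `sampling_params3`) keeps each case's inputs explicit. A minimal illustration of the rebinding hazard the renames avoid (values here are placeholders, not the test's):

# Rebinding one name across test cases couples them:
sampling_params = {"temperature": 0, "max_new_tokens": 8}     # Case 1
sampling_params = {"temperature": 0, "max_new_tokens": 1024}  # Case 2 overwrites Case 1
assert sampling_params["max_new_tokens"] == 1024              # later code silently sees Case 2's value

# Distinct names keep each case's inputs stable and self-describing:
sampling_params1 = {"temperature": 0, "max_new_tokens": 8}
sampling_params2 = {"temperature": 0, "max_new_tokens": 1024}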