Unverified Commit 25e7dbe8 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix ngram spec with page size > 1 (#11135)

parent 0b2aa8a7
...@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to( seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True self.device, non_blocking=True
) )
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64) seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to( orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True self.device, non_blocking=True
) )
...@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor, prefix_lens_tensor,
prefix_lens_cpu_tensor, prefix_lens_cpu_tensor,
seq_lens_tensor, seq_lens_tensor,
seq_lens_cpu_tensor, seq_lens_cpu,
last_loc, last_loc,
extend_num_tokens, extend_num_tokens,
) )
...@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.input_ids = input_ids_tensor self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu_tensor self.seq_lens_cpu = seq_lens_cpu
self.orig_seq_lens = orig_seq_lens_tensor self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc self.out_cache_loc = out_cache_loc
self.input_embeds = ( self.input_embeds = (
......
...@@ -1087,7 +1087,10 @@ class ServerArgs: ...@@ -1087,7 +1087,10 @@ class ServerArgs:
and self.attention_backend != "flashinfer" and self.attention_backend != "flashinfer"
): ):
raise ValueError( raise ValueError(
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend." f"speculative_eagle_topk({self.speculative_eagle_topk}) > 1 "
f"with page_size({self.page_size}) > 1 is unstable "
"and produces incorrect results for paged attention backends. "
"This combination is only supported for the 'flashinfer' backend."
) )
if self.enable_dp_attention: if self.enable_dp_attention:
# TODO: support dp attention for ngram speculative decoding # TODO: support dp attention for ngram speculative decoding
......
...@@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput): ...@@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput):
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False evict_mask[accept_index] = False
accept_length_cpu = accept_length.cpu() accept_length_cpu = accept_length.cpu()
# FIXME: this `tolist()` fixes the numerical calculation consistency
# try to unify the tensor representation and list representation
accept_length_list = accept_length_cpu.tolist() accept_length_list = accept_length_cpu.tolist()
if page_size == 1: if page_size == 1:
......
...@@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput): ...@@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput):
else: else:
# TODO(lsyin): add prefix lens cpu here to support page size > 1 # TODO(lsyin): add prefix lens cpu here to support page size > 1
prefix_lens = batch.seq_lens prefix_lens = batch.seq_lens
prefix_lens_cpu = batch.seq_lens_cpu
end_offset = prefix_lens + self.draft_token_num end_offset = prefix_lens + self.draft_token_num
end_offset_cpu = prefix_lens_cpu + self.draft_token_num
last_loc = get_last_loc( last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token, batch.req_to_token_pool.req_to_token,
batch.req_pool_indices, batch.req_pool_indices,
prefix_lens, prefix_lens,
) )
batch.out_cache_loc = batch.alloc_paged_token_slots_extend( batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
prefix_lens, end_offset, last_loc, len(batch.input_ids) prefix_lens,
prefix_lens_cpu,
end_offset,
end_offset_cpu,
last_loc,
len(batch.input_ids),
) )
self.last_loc = last_loc self.last_loc = last_loc
......
...@@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [ ...@@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [
] ]
class TestStandaloneSpeculativeDecodingBase(CustomTestCase): class TestNgramSpeculativeDecodingBase(CustomTestCase):
model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST base_url = DEFAULT_URL_FOR_TEST
...@@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase): ...@@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold) self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase): class TestNgramSpeculativeDecodingTriton(TestNgramSpeculativeDecodingBase):
@classmethod @classmethod
def get_server_args(cls): def get_server_args(cls):
return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"] return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]
class TestStandaloneSpeculativeDecodingFlashinfer( class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase):
TestStandaloneSpeculativeDecodingBase
):
@classmethod @classmethod
def get_server_args(cls): def get_server_args(cls):
return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"] return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]
class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase):
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + [
"--attention-backend",
"flashinfer",
"--page-size",
"64",
]
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment