Unverified Commit 25e7dbe8 authored by Liangsheng Yin, committed by GitHub
Browse files

Fix ngram spec with page size > 1 (#11135)

parent 0b2aa8a7
......@@ -1229,7 +1229,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
......@@ -1366,7 +1366,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor,
prefix_lens_cpu_tensor,
seq_lens_tensor,
seq_lens_cpu_tensor,
seq_lens_cpu,
last_loc,
extend_num_tokens,
)
......@@ -1375,7 +1375,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu_tensor
self.seq_lens_cpu = seq_lens_cpu
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
self.input_embeds = (
......
......@@ -1087,7 +1087,10 @@ class ServerArgs:
and self.attention_backend != "flashinfer"
):
raise ValueError(
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
f"speculative_eagle_topk({self.speculative_eagle_topk}) > 1 "
f"with page_size({self.page_size}) > 1 is unstable "
"and produces incorrect results for paged attention backends. "
"This combination is only supported for the 'flashinfer' backend."
)
if self.enable_dp_attention:
# TODO: support dp attention for ngram speculative decoding
......
......@@ -388,6 +388,8 @@ class EagleVerifyInput(SpecInput):
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
accept_length_cpu = accept_length.cpu()
# FIXME: calling `tolist()` here keeps the numerical results consistent;
# eventually unify the tensor and list representations of accept_length
accept_length_list = accept_length_cpu.tolist()
if page_size == 1:
......
......@@ -79,14 +79,21 @@ class NgramVerifyInput(SpecInput):
else:
# TODO(lsyin): add prefix lens cpu here to support page size > 1
prefix_lens = batch.seq_lens
prefix_lens_cpu = batch.seq_lens_cpu
end_offset = prefix_lens + self.draft_token_num
end_offset_cpu = prefix_lens_cpu + self.draft_token_num
last_loc = get_last_loc(
batch.req_to_token_pool.req_to_token,
batch.req_pool_indices,
prefix_lens,
)
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
prefix_lens, end_offset, last_loc, len(batch.input_ids)
prefix_lens,
prefix_lens_cpu,
end_offset,
end_offset_cpu,
last_loc,
len(batch.input_ids),
)
self.last_loc = last_loc
......
......@@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [
]
class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
class TestNgramSpeculativeDecodingBase(CustomTestCase):
model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
......@@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase):
class TestNgramSpeculativeDecodingTriton(TestNgramSpeculativeDecodingBase):
    """Ngram speculative decoding suite run against the triton attention backend."""

    @classmethod
    def get_server_args(cls):
        # Reuse the shared launch flags and override only the backend choice.
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "triton"]
class TestStandaloneSpeculativeDecodingFlashinfer(
TestStandaloneSpeculativeDecodingBase
):
class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase):
    """Ngram speculative decoding suite run against the flashinfer attention backend."""

    @classmethod
    def get_server_args(cls):
        # Reuse the shared launch flags and override only the backend choice.
        return [*DEFAULT_SERVER_ARGS, "--attention-backend", "flashinfer"]
class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase):
    """Ngram speculative decoding suite with a paged KV cache (page_size > 1).

    Pins the flashinfer backend, which (per the ServerArgs validation
    elsewhere in this change) is the only backend supported when
    page_size > 1 is combined with speculation.
    """

    @classmethod
    def get_server_args(cls):
        paged_flags = [
            "--attention-backend",
            "flashinfer",
            "--page-size",
            "64",
        ]
        return DEFAULT_SERVER_ARGS + paged_flags
# Run the full test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment