Unverified commit 019851d0, authored by Lianmin Zheng, committed by GitHub
Browse files

Fix eagle on AMD (#7051)

parent 2dae104d
...@@ -123,6 +123,9 @@ class EagleDraftInput: ...@@ -123,6 +123,9 @@ class EagleDraftInput:
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
if paged_kernel_lens_sum is None:
paged_kernel_lens_sum = cum_kv_seq_len[-1]
kv_indices = torch.empty( kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda" paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
) )
......
...@@ -194,7 +194,7 @@ class TestBenchServing(CustomTestCase): ...@@ -194,7 +194,7 @@ class TestBenchServing(CustomTestCase):
self.assertLess(res["median_ttft_ms"], 150) self.assertLess(res["median_ttft_ms"], 150)
# TODO: not set yet, need AMD machine # TODO: not set yet, need AMD machine
else: else:
self.assertLess(res["median_ttft_ms"], 94) self.assertLess(res["median_ttft_ms"], 98)
self.assertLess(res["median_itl_ms"], 8) self.assertLess(res["median_itl_ms"], 8)
def test_online_latency_eagle(self): def test_online_latency_eagle(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment