Unverified commit 1801cd19, authored by narutolhy and committed by GitHub

Support more models in piecewise CUDA graph (#11745)

parent ffc722a6
@@ -142,8 +142,11 @@ def unified_attention_with_output(
     ret = forward_batch.attn_backend.forward(
         query, key, value, attention_layer, forward_batch, save_kv_cache
     )
-    assert output.shape == ret.shape
-    output.copy_(ret)
+    assert (
+        output.numel() == ret.numel()
+    ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
+    output.view(ret.shape).copy_(ret)
     return
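The strict shape equality check is relaxed to an element-count check so that attention backends whose forward() returns a tensor with a different layout than the preallocated output buffer can still write into it via a view. A minimal sketch of the idea; the tensor shapes below are illustrative assumptions, not taken from the repository:

import torch

# Preallocated output buffer with a flattened hidden dimension.
output = torch.empty(8, 32 * 128)   # [num_tokens, num_heads * head_dim]
# Backend result with a per-head layout but the same number of elements.
ret = torch.randn(8, 32, 128)       # [num_tokens, num_heads, head_dim]

# Shape equality would fail here, but the element counts match,
# so a view of the buffer can receive the result in place.
assert output.numel() == ret.numel()
output.view(ret.shape).copy_(ret)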
@@ -262,9 +262,14 @@ class PiecewiseCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         num_tokens = len(forward_batch.input_ids)
-        # TODO(yuwei): support return logprob
+        # TODO(yuwei): support return input_ids' logprob
         if forward_batch.return_logprob:
             return False
+        for start_len, seq_len in zip(
+            forward_batch.extend_logprob_start_lens_cpu,
+            forward_batch.extend_seq_lens_cpu,
+        ):
+            if start_len is not None and start_len < seq_len:
+                return False
         if num_tokens <= self.max_num_tokens:
             return True
         return False
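can_run now also rejects prefill batches that request logprobs for part of the input tokens (a logprob start position that falls inside the extend length), since the piecewise CUDA graph path does not produce them yet. A rough sketch of the guard in isolation; the field names follow the diff, while the helper itself is illustrative:

# Minimal sketch, assuming the CPU-side lists mirror the diff above.
def needs_input_logprob(extend_logprob_start_lens_cpu, extend_seq_lens_cpu):
    # A request needs per-input-token logprobs when its logprob start
    # position lies strictly inside the extended sequence.
    for start_len, seq_len in zip(extend_logprob_start_lens_cpu, extend_seq_lens_cpu):
        if start_len is not None and start_len < seq_len:
            return True
    return False

# Example: the second request wants logprobs from token 3 of a 10-token extend,
# so the batch falls back to the eager (non-graph) path.
print(needs_input_logprob([10, 3], [10, 10]))  # True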
@@ -438,7 +443,7 @@ class PiecewiseCudaGraphRunner:
             out_cache_loc=out_cache_loc,
             seq_lens_sum=forward_batch.seq_lens_sum,
             encoder_lens=forward_batch.encoder_lens,
-            return_logprob=forward_batch.return_logprob,
+            return_logprob=False,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_prefix_lens=forward_batch.extend_prefix_lens,
             extend_start_loc=forward_batch.extend_start_loc,
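Because can_run already filters out any batch that requests logprobs, the ForwardBatch built for the piecewise graph can pin return_logprob to False instead of copying the flag from the incoming batch. A tiny sketch of the invariant this relies on; everything beyond the fields shown in the diff is an assumption:

# Sketch only: the graph-path batch never carries logprob work,
# so the flag is pinned rather than mirrored from the request.
def build_graph_batch_kwargs(forward_batch, out_cache_loc):
    assert not forward_batch.return_logprob  # guaranteed by can_run()
    return dict(
        out_cache_loc=out_cache_loc,
        seq_lens_sum=forward_batch.seq_lens_sum,
        return_logprob=False,  # pinned: the piecewise graph path skips logprobs
    )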
@@ -44,6 +44,18 @@ class TestPiecewiseCudaGraphCorrectness(CustomTestCase):
         metrics = run_eval(args)
         self.assertGreaterEqual(metrics["score"], 0.235)
 
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+        metrics = run_eval(args)
+        self.assertGreaterEqual(metrics["score"], 0.65)
+
 
 class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
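The new test_mmlu case follows the same pattern as the existing eval: assemble the arguments, run a small MMLU slice through run_eval, and check the score against a floor. A hedged sketch of the call the test wraps; the base URL and model name below are placeholders, not values from the diff:

from types import SimpleNamespace

# Sketch of the eval invocation, assuming a locally running server.
args = SimpleNamespace(
    base_url="http://127.0.0.1:30000",  # assumed local server address
    model="<model-name>",               # placeholder model identifier
    eval_name="mmlu",
    num_examples=64,
    num_threads=32,
)
# metrics = run_eval(args)
# assert metrics["score"] >= 0.65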