Unverified Commit 1801cd19 authored by narutolhy's avatar narutolhy Committed by GitHub
Browse files

support more model in piecewise cuda graph (#11745)

parent ffc722a6
...@@ -142,8 +142,11 @@ def unified_attention_with_output( ...@@ -142,8 +142,11 @@ def unified_attention_with_output(
ret = forward_batch.attn_backend.forward( ret = forward_batch.attn_backend.forward(
query, key, value, attention_layer, forward_batch, save_kv_cache query, key, value, attention_layer, forward_batch, save_kv_cache
) )
assert output.shape == ret.shape assert (
output.copy_(ret) output.numel() == ret.numel()
), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
output.view(ret.shape).copy_(ret)
return return
......
...@@ -262,8 +262,13 @@ class PiecewiseCudaGraphRunner: ...@@ -262,8 +262,13 @@ class PiecewiseCudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch): def can_run(self, forward_batch: ForwardBatch):
num_tokens = len(forward_batch.input_ids) num_tokens = len(forward_batch.input_ids)
# TODO(yuwei): support return logprob # TODO(yuwei): support return input_ids' logprob
if forward_batch.return_logprob: if forward_batch.return_logprob:
for start_len, seq_len in zip(
forward_batch.extend_logprob_start_lens_cpu,
forward_batch.extend_seq_lens_cpu,
):
if start_len is not None and start_len < seq_len:
return False return False
if num_tokens <= self.max_num_tokens: if num_tokens <= self.max_num_tokens:
return True return True
...@@ -438,7 +443,7 @@ class PiecewiseCudaGraphRunner: ...@@ -438,7 +443,7 @@ class PiecewiseCudaGraphRunner:
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
seq_lens_sum=forward_batch.seq_lens_sum, seq_lens_sum=forward_batch.seq_lens_sum,
encoder_lens=forward_batch.encoder_lens, encoder_lens=forward_batch.encoder_lens,
return_logprob=forward_batch.return_logprob, return_logprob=False,
extend_seq_lens=forward_batch.extend_seq_lens, extend_seq_lens=forward_batch.extend_seq_lens,
extend_prefix_lens=forward_batch.extend_prefix_lens, extend_prefix_lens=forward_batch.extend_prefix_lens,
extend_start_loc=forward_batch.extend_start_loc, extend_start_loc=forward_batch.extend_start_loc,
......
...@@ -44,6 +44,18 @@ class TestPiecewiseCudaGraphCorrectness(CustomTestCase): ...@@ -44,6 +44,18 @@ class TestPiecewiseCudaGraphCorrectness(CustomTestCase):
metrics = run_eval(args) metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.235) self.assertGreaterEqual(metrics["score"], 0.235)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.65)
class TestPiecewiseCudaGraphBenchmark(CustomTestCase): class TestPiecewiseCudaGraphBenchmark(CustomTestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment