"tests/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "cdfbdc0bd108a315ff8bf5ae5ef877084f30336a"
Unverified commit 022f6515, authored by Nicolas Patry, committed by GitHub
Browse files

Fixing graph capture for flash decoding. (#2163)

parent 4327210e
......@@ -926,7 +926,7 @@ class FlashCausalLM(Model):
"slots": slots,
"input_lengths": input_lengths,
}
-input_lengths = Seqlen(input_lengths=input_lengths)
+input_lengths_ = Seqlen(input_lengths=input_lengths)
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
......@@ -939,7 +939,7 @@ class FlashCausalLM(Model):
kv_cache=self.kv_cache,
block_tables=block_tables,
slots=slots,
-input_lengths=input_lengths,
+input_lengths=input_lengths_,
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
......@@ -947,6 +947,7 @@ class FlashCausalLM(Model):
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
+input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment