Unverified commit 99aefa03, authored by Jay Zhou and committed by GitHub
Browse files

Fix eagle3 cuda graph (#8163)

parent bbcfbc1a
......@@ -84,7 +84,15 @@ class EAGLEDraftExtendCudaGraphRunner:
self.hidden_states = torch.zeros(
(
self.max_num_token,
self.model_runner.model_config.hidden_size * 3,
(
self.model_runner.model_config.hf_config.target_hidden_size
* 3
if hasattr(
self.model_runner.model_config.hf_config,
"target_hidden_size",
)
else self.model_runner.model_config.hidden_size * 3
),
),
dtype=self.model_runner.dtype,
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment