Unverified Commit 99aefa03 authored by Jay Zhou's avatar Jay Zhou Committed by GitHub
Browse files

Fix eagle3 cuda graph (#8163)

parent bbcfbc1a
...@@ -84,7 +84,15 @@ class EAGLEDraftExtendCudaGraphRunner: ...@@ -84,7 +84,15 @@ class EAGLEDraftExtendCudaGraphRunner:
self.hidden_states = torch.zeros( self.hidden_states = torch.zeros(
( (
self.max_num_token, self.max_num_token,
self.model_runner.model_config.hidden_size * 3, (
self.model_runner.model_config.hf_config.target_hidden_size
* 3
if hasattr(
self.model_runner.model_config.hf_config,
"target_hidden_size",
)
else self.model_runner.model_config.hidden_size * 3
),
), ),
dtype=self.model_runner.dtype, dtype=self.model_runner.dtype,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment