Unverified commit f8d48fd3, authored by Cheng Wan and committed by GitHub
Browse files

Fix dtype for idle input in spec decoding (#7456)

parent 34b6b842
...@@ -89,14 +89,13 @@ class EagleDraftInput: ...@@ -89,14 +89,13 @@ class EagleDraftInput:
cls, cls,
device: torch.device, device: torch.device,
hidden_size: int, hidden_size: int,
dtype: torch.dtype,
topk: int, topk: int,
capture_hidden_mode: CaptureHiddenMode, capture_hidden_mode: CaptureHiddenMode,
): ):
return cls( return cls(
verified_id=None, verified_id=None,
hidden_states=torch.empty( hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
(0, hidden_size), device=device, dtype=torch.float32
),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
capture_hidden_mode=capture_hidden_mode, capture_hidden_mode=capture_hidden_mode,
...@@ -334,6 +333,7 @@ class EagleVerifyInput: ...@@ -334,6 +333,7 @@ class EagleVerifyInput:
draft_input=EagleDraftInput.create_idle_input( draft_input=EagleDraftInput.create_idle_input(
device=batch.device, device=batch.device,
hidden_size=batch.model_config.hidden_size, hidden_size=batch.model_config.hidden_size,
dtype=batch.model_config.dtype,
topk=self.topk, topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST, capture_hidden_mode=CaptureHiddenMode.LAST,
), ),
......
...@@ -498,6 +498,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -498,6 +498,7 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info = EagleDraftInput.create_idle_input( batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device, device=self.device,
hidden_size=self.model_config.hidden_size, hidden_size=self.model_config.hidden_size,
dtype=self.model_config.dtype,
topk=self.topk, topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST, capture_hidden_mode=CaptureHiddenMode.LAST,
) )
...@@ -838,6 +839,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -838,6 +839,7 @@ class EAGLEWorker(TpModelWorker):
batch.spec_info = EagleDraftInput.create_idle_input( batch.spec_info = EagleDraftInput.create_idle_input(
device=self.device, device=self.device,
hidden_size=self.model_config.hidden_size, hidden_size=self.model_config.hidden_size,
dtype=self.model_config.dtype,
topk=self.topk, topk=self.topk,
capture_hidden_mode=CaptureHiddenMode.LAST, capture_hidden_mode=CaptureHiddenMode.LAST,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment