Unverified Commit 40b8363b authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[MRV2] Use fp32 for draft logits (#37526)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 8b10e4fb
......@@ -195,7 +195,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
model_dtype=self.dtype,
cache_draft_logits=not use_strict_rejection_sampling,
)
self.input_buffers = InputBuffers(
......
......@@ -15,7 +15,6 @@ class RequestState:
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
model_dtype: torch.dtype,
cache_draft_logits: bool,
):
self.max_num_reqs = max_num_reqs
......@@ -81,7 +80,7 @@ class RequestState:
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=model_dtype,
dtype=torch.float32,
device=device,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment