Unverified Commit 6c903611 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Fix incorrect spec_num_draft_tokens in draft_extend (#7757)

parent 77cfea68
...@@ -237,6 +237,10 @@ def _dp_gather( ...@@ -237,6 +237,10 @@ def _dp_gather(
assert ( assert (
local_tokens.untyped_storage() is not global_tokens.untyped_storage() local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between global_tokens and local_tokens not allowed" ), "aliasing between global_tokens and local_tokens not allowed"
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend(): if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
...@@ -291,6 +295,10 @@ def dp_scatter( ...@@ -291,6 +295,10 @@ def dp_scatter(
assert ( assert (
local_tokens.untyped_storage() is not global_tokens.untyped_storage() local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between local_tokens and global_tokens not allowed" ), "aliasing between local_tokens and global_tokens not allowed"
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend(): if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0]) shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor) local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
......
...@@ -844,7 +844,7 @@ class EAGLEWorker(TpModelWorker): ...@@ -844,7 +844,7 @@ class EAGLEWorker(TpModelWorker):
) )
batch.return_hidden_states = False batch.return_hidden_states = False
model_worker_batch = batch.get_model_worker_batch() model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
forward_batch = ForwardBatch.init_new( forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner model_worker_batch, self.draft_model_runner
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment