Unverified commit 0e7a5b26, authored by J and committed by GitHub
Browse files

fix: prevent crashes due to logit bias dimension mismatch (#7685)

parent 4953f4ca
...@@ -322,6 +322,12 @@ class SamplingBatchInfo: ...@@ -322,6 +322,12 @@ class SamplingBatchInfo:
# Set the flag to True if any of the two has custom logit processor # Set the flag to True if any of the two has custom logit processor
self.has_custom_logit_processor = True self.has_custom_logit_processor = True
# Merge logit bias - note this has to come before the temperatures tensor update! Otherwise will cause crashes.
# See note below on len(self) and len(other).
self.logit_bias = merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
)
# Note: because the __len()__ operator is defined on the temperatures tensor, # Note: because the __len()__ operator is defined on the temperatures tensor,
# please make sure any merge operation with len(self) or len(other) is done before # please make sure any merge operation with len(self) or len(other) is done before
# the merge operation of the temperatures tensor below. # the merge operation of the temperatures tensor below.
...@@ -340,11 +346,6 @@ class SamplingBatchInfo: ...@@ -340,11 +346,6 @@ class SamplingBatchInfo:
self.need_top_k_sampling |= other.need_top_k_sampling self.need_top_k_sampling |= other.need_top_k_sampling
self.need_min_p_sampling |= other.need_min_p_sampling self.need_min_p_sampling |= other.need_min_p_sampling
# Merge logit bias
self.logit_bias = merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
)
def merge_bias_tensor( def merge_bias_tensor(
lhs: Optional[torch.Tensor], lhs: Optional[torch.Tensor],
......
from __future__ import annotations from __future__ import annotations
import copy
import logging import logging
import os import os
import time import time
...@@ -362,6 +363,11 @@ class EagleVerifyInput: ...@@ -362,6 +363,11 @@ class EagleVerifyInput:
) )
accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda") accept_length = torch.empty((bs,), dtype=torch.int32, device="cuda")
if bs != len(sampling_info):
sampling_info = copy.deepcopy(sampling_info)
# NOTE: retrive_index are the indices of the requests that are kept.
sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
# Apply the custom logit processors if registered in the sampling info. # Apply the custom logit processors if registered in the sampling info.
if sampling_info.has_custom_logit_processor: if sampling_info.has_custom_logit_processor:
apply_custom_logit_processor( apply_custom_logit_processor(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment