Unverified Commit eb19ccad authored by Liangsheng Yin, committed by GitHub

[bug] fix errors related to context length in SD (#9388)

parent 25ef53f0
@@ -32,6 +32,7 @@ from sglang.srt.hf_transformers_utils import (
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, is_hip
+from sglang.utils import is_in_ci

 logger = logging.getLogger(__name__)
@@ -166,19 +167,20 @@ class ModelConfig:
         derived_context_len = get_context_length(self.hf_text_config)
         if context_length is not None:
             if context_length > derived_context_len:
-                if get_bool_env_var(
-                    "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", default="True"
+                reason = "Target model's" if is_draft_model else "User-specified"
+                msg = (
+                    f"Warning: {reason} context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
+                    f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config."
+                )
+                if (
+                    get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN")
+                    or is_in_ci()  # FIXME: fix this special case
                 ):
-                    logger.warning(
-                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
-                        f"This may lead to incorrect model outputs or CUDA errors."
-                    )
+                    logger.warning(msg)
                     self.context_len = context_length
                 else:
                     raise ValueError(
-                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
-                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
-                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
+                        f"{msg} To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                     )
             else:
                 self.context_len = context_length
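In short, the warning/error message is now built once and shared by both branches, it is attributed to the target model when the config belongs to a draft model, and overriding a longer-than-derived context length is opt-in (the env var no longer defaults to true), with CI temporarily exempted via is_in_ci(). A minimal standalone sketch of the resulting decision flow, using plain os.environ in place of sglang's get_bool_env_var()/is_in_ci() helpers (the names below are illustrative, not the sglang API):

# Minimal sketch of the new decision flow (illustrative only, not the sglang API);
# plain os.environ stands in for sglang's get_bool_env_var() / is_in_ci() helpers.
import logging
import os

logger = logging.getLogger(__name__)


def resolve_context_len(context_length, derived_context_len, is_draft_model=False):
    if context_length is None:
        return derived_context_len
    if context_length <= derived_context_len:
        return context_length
    reason = "Target model's" if is_draft_model else "User-specified"
    msg = (
        f"{reason} context_length ({context_length}) is greater than the "
        f"derived context_length ({derived_context_len})."
    )
    # Overriding is now opt-in: the env var no longer defaults to true.
    allow = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", "false")
    if allow.lower() in ("1", "true"):
        logger.warning(msg)
        return context_length
    raise ValueError(
        msg + " Set SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 to allow overriding."
    )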
@@ -576,7 +576,7 @@ class TokenizerManager:
                     f"model's context length ({self.context_len} tokens). "
                     "Truncating the input."
                 )
-                input_ids = input_ids[:_max_req_len]
+                del input_ids[_max_req_len:]
                 input_token_num = len(input_ids)
             else:
                 raise ValueError(
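The truncation itself switches from rebinding a slice to deleting the tail in place. Because del mutates the existing list object, every other reference to that same list also sees the shortened input, whereas the old slice assignment only rebound the local name. A small illustration with hypothetical names:

# Why the change matters: del truncates the shared list object in place,
# while assigning a slice back to the name only rebinds the local variable.
ids = list(range(10))
alias = ids            # e.g. another component holding the same list

ids = ids[:5]          # old style: rebinds `ids` only
assert len(ids) == 5 and len(alias) == 10

ids = alias
del ids[5:]            # new style: truncates the one shared list
assert len(ids) == 5 and len(alias) == 5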
@@ -1236,6 +1236,11 @@ class ModelRunner:
         # Initialize req_to_token_pool
         if self.req_to_token_pool is None:
+            # FIXME(lsyin): this is the temporary fix for the context length issue when using speculative decoding
+            extra_max_context_len = 4
+            if self.server_args.speculative_num_draft_tokens is not None:
+                extra_max_context_len += self.server_args.speculative_num_draft_tokens
+
             if self.server_args.disaggregation_mode == "decode":
                 from sglang.srt.disaggregation.decode import DecodeReqToTokenPool

@@ -1244,7 +1249,8 @@ class ModelRunner:
                 pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
                 self.req_to_token_pool = DecodeReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len + 4,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                     pre_alloc_size=pre_alloc_size,

@@ -1252,7 +1258,8 @@ class ModelRunner:
             else:
                 self.req_to_token_pool = ReqToTokenPool(
                     size=max_num_reqs,
-                    max_context_len=self.model_config.context_len + 4,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
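Each row of the req_to_token pool was previously sized context_len + 4; with speculative decoding enabled, the draft/verify steps can apparently index a few positions past that limit, so the row now also reserves speculative_num_draft_tokens extra slots. A rough sizing sketch with made-up numbers, mirroring the logic above:

# Rough sizing sketch (made-up numbers) of the per-request row width.
context_len = 32768
speculative_num_draft_tokens = 8      # None when speculative decoding is disabled

extra_max_context_len = 4
if speculative_num_draft_tokens is not None:
    extra_max_context_len += speculative_num_draft_tokens

max_context_len = context_len + extra_max_context_len   # 32780 instead of the old 32772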
@@ -41,6 +41,7 @@ class EAGLEDraftCudaGraphRunner:
         # Parse args
         self.eagle_worker = eagle_worker
         self.model_runner = model_runner = eagle_worker.model_runner
+        self.model_runner: EAGLEWorker
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -9,7 +9,6 @@ from huggingface_hub import snapshot_download
 from sglang.srt.distributed import (
     GroupCoordinator,
-    get_tensor_model_parallel_world_size,
     get_tp_group,
     patch_tensor_parallel_group,
 )

@@ -92,7 +91,7 @@ class EAGLEWorker(TpModelWorker):
         )
         self.padded_static_len = -1

-        # Override context length with target model's context length
+        # Override the context length of the draft model to be the same as the target model.
         server_args.context_length = target_worker.model_runner.model_config.context_len

         # Do not capture cuda graph in `super().__init__()`
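This override is what feeds the is_draft_model branch added to ModelConfig above: the draft worker inherits the target model's already-resolved context_len instead of re-deriving one from the (often shorter) draft checkpoint config. A simplified sketch of the assumed ordering, not the exact sglang call chain:

# Simplified flow (assumed ordering, not the exact sglang call chain).
server_args.context_length = target_worker.model_runner.model_config.context_len
# The draft model's config is then built with is_draft_model=True, so if this value
# exceeds what the draft checkpoint derives, the check above warns about the
# "Target model's" context_length rather than raising a user-facing error.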