Unverified Commit ebaa2f31 authored by Lianmin Zheng, committed by GitHub

Rename arguments `--disable-nan-detection` to `--enable-nan-detection` (#2066)

parent 62832bb2
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 class Sampler(nn.Module):
     def __init__(self):
         super().__init__()
-        self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+        self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]
 
     def forward(
         self,
...
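For context, here is a minimal sketch of the behavior this flag gates (illustrative only, not the exact SGLang sampler code): when NaN detection is enabled, the sampler checks the probability tensor for NaNs before sampling and patches them so sampling does not crash. The function name and fallback value below are hypothetical.

import torch

def sample_token(probs: torch.Tensor, enable_nan_detection: bool) -> torch.Tensor:
    # probs: (batch, vocab) probabilities produced by the model head.
    if enable_nan_detection and torch.any(torch.isnan(probs)):
        # Replace NaNs with a tiny probability so torch.multinomial still works.
        probs = torch.where(torch.isnan(probs), torch.full_like(probs, 1e-10), probs)
    return torch.multinomial(probs, num_samples=1)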
@@ -57,7 +57,7 @@ global_server_args_dict = {
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
-    "disable_nan_detection": ServerArgs.disable_nan_detection,
+    "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
 }
...
@@ -139,7 +139,7 @@ class ModelRunner:
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
                 "disable_penalizer": server_args.disable_penalizer,
-                "disable_nan_detection": server_args.disable_nan_detection,
+                "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
@@ -276,6 +276,10 @@ class ModelRunner:
             else None
         )
         self.dtype = self.vllm_model_config.dtype
+        if self.sliding_window_size:
+            assert (
+                self.server_args.attention_backend == "flashinfer"
+            ), "Only flashinfer supports window attention."
 
         logger.info(
             f"Load weight end. "
...
@@ -332,6 +332,7 @@ class Gemma2ForCausalLM(nn.Module):
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}
     embedding_padding_modules = []
+    supports_lora = True
 
     def __init__(
         self,
...
@@ -124,7 +124,6 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
     disable_penalizer: bool = False
-    disable_nan_detection: bool = False
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
@@ -132,6 +131,7 @@ class ServerArgs:
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: int = 160
     torchao_config: str = ""
+    enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
@@ -171,11 +171,11 @@ class ServerArgs:
         else:
             gpu_mem = get_nvgpu_memory_capacity()
             if gpu_mem < 25000:
+                self.chunked_prefill_size //= 4  # make it 2048
+                self.cuda_graph_max_bs = 4
                 logger.warning(
                     "Automatically adjust --chunked-prefill-size for small GPUs."
                 )
-                self.chunked_prefill_size //= 4  # make it 2048
-                self.cuda_graph_max_bs = 4
 
         if not is_flashinfer_available():
             self.attention_backend = "triton"
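The small-GPU heuristic above is easy to check in isolation. A minimal sketch, assuming the default chunked_prefill_size of 8192 that the "make it 2048" comment implies and a GPU reporting less than 25000 MB:

# Standalone sketch of the small-GPU adjustment; the 8192 default and the
# 24000 MB figure are assumptions for illustration.
chunked_prefill_size = 8192
cuda_graph_max_bs = 160
gpu_mem = 24000  # MB, e.g. a 24 GB consumer GPU

if gpu_mem < 25000:
    chunked_prefill_size //= 4  # 8192 // 4 == 2048
    cuda_graph_max_bs = 4

print(chunked_prefill_size, cuda_graph_max_bs)  # 2048 4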
@@ -194,7 +194,7 @@ class ServerArgs:
             self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
             self.enable_overlap_schedule = False
             logger.warning(
-                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
+                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size."
             )
@@ -204,21 +204,8 @@ class ServerArgs:
                 "Overlap scheduler mode is enabled. This is an experimental feature. "
                 "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
                 "and embedding APIs are not supported and will lead to wrong results. "
-                "The NaN detection is also disabled."
             )
             self.disable_penalizer = True
-            self.disable_nan_detection = True
-
-        # Model-specific patches
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
-            logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
-            )
-            self.trust_remote_code = False
-
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.attention_backend = "flashinfer"
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -683,6 +670,11 @@ class ServerArgs:
             default=ServerArgs.torchao_config,
             help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
         )
+        parser.add_argument(
+            "--enable-nan-detection",
+            action="store_true",
+            help="Enable the NaN detection for debugging purposes.",
+        )
         parser.add_argument(
             "--enable-p2p-check",
             action="store_true",
...
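Because the new flag is a store_true argument, NaN detection is now opt-in and defaults to off. A small sketch of that CLI behavior, using only the argument definition added above (a bare argparse parser stands in for ServerArgs.add_cli_args):

import argparse

# Reproduce just the argument added in this commit on a bare parser.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-nan-detection",
    action="store_true",
    help="Enable the NaN detection for debugging purposes.",
)

print(parser.parse_args([]).enable_nan_detection)                          # False
print(parser.parse_args(["--enable-nan-detection"]).enable_nan_detection)  # True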