Unverified Commit fa909dc3 authored by Yineng Zhang, committed by GitHub

feat: update model_specific_adjustment (#5344)


Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
parent e8f62b20
@@ -383,7 +383,7 @@ class FlashAttentionBackend(AttentionBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
-        elif forward_batch.forward_mode.is_extend_or_draft_extend():
+        elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
             metadata.cu_seqlens_k = torch.nn.functional.pad(
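The last line of this hunk is cut off in this view. As a rough, assumed sketch of how the cumulative key-length tensor is typically completed for these extend/mixed batches (the padding and cumsum arguments are an assumption, not taken from the diff):

import torch

# Per-request KV lengths for a small illustrative batch.
seqlens_in_batch = torch.tensor([5, 3, 7], dtype=torch.int32)

cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
max_seq_len_k = int(seqlens_in_batch.max().item())

# cu_seqlens_k is the exclusive prefix sum with a leading zero, the layout
# that varlen flash-attention kernels expect: [0, 5, 8, 15].
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens_k)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)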
@@ -78,7 +78,7 @@ class ForwardMode(IntEnum):
             self == ForwardMode.EXTEND
             or self == ForwardMode.MIXED
             or self == ForwardMode.DRAFT_EXTEND
-            or self == self.TARGET_VERIFY
+            or self == ForwardMode.TARGET_VERIFY
         )

     def is_decode(self):
@@ -96,6 +96,13 @@ class ForwardMode(IntEnum):
     def is_draft_extend(self):
         return self == ForwardMode.DRAFT_EXTEND

+    def is_extend_or_draft_extend_or_mixed(self):
+        return (
+            self == ForwardMode.EXTEND
+            or self == ForwardMode.DRAFT_EXTEND
+            or self == ForwardMode.MIXED
+        )
+
     def is_cuda_graph(self):
         return (
             self == ForwardMode.DECODE
@@ -103,9 +110,6 @@ class ForwardMode(IntEnum):
             or self == ForwardMode.IDLE
         )

-    def is_extend_or_draft_extend(self):
-        return self == ForwardMode.EXTEND or self == ForwardMode.DRAFT_EXTEND
-
     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
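As a quick illustration of what the renamed predicate covers, a self-contained sketch with a trimmed-down enum (only the members named in this diff; the real ForwardMode has more values):

from enum import IntEnum, auto


class ForwardMode(IntEnum):
    # Trimmed to the members that appear in this diff.
    EXTEND = auto()
    MIXED = auto()
    DECODE = auto()
    IDLE = auto()
    DRAFT_EXTEND = auto()
    TARGET_VERIFY = auto()

    def is_extend_or_draft_extend_or_mixed(self):
        return (
            self == ForwardMode.EXTEND
            or self == ForwardMode.DRAFT_EXTEND
            or self == ForwardMode.MIXED
        )


# MIXED now takes the extend-style metadata path in FlashAttentionBackend,
# while DECODE and TARGET_VERIFY still do not.
assert ForwardMode.MIXED.is_extend_or_draft_extend_or_mixed()
assert not ForwardMode.DECODE.is_extend_or_draft_extend_or_mixed()
assert not ForwardMode.TARGET_VERIFY.is_extend_or_draft_extend_or_mixed()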
@@ -78,9 +78,11 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     init_custom_process_group,
     is_cuda,
+    is_fa3_default_architecture,
     is_flashinfer_available,
     is_hip,
     is_hopper_with_cuda_12_3,
+    is_no_spec_infer_or_topk_one,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -242,18 +244,21 @@ class ModelRunner:
         elif server_args.attention_backend is None:
             # By default, use flashinfer for non-mla attention and triton for mla attention
             if not self.use_mla_backend:
-                server_args.attention_backend = (
-                    "flashinfer" if is_flashinfer_available() else "triton"
-                )
+                if (
+                    is_hopper_with_cuda_12_3()
+                    and is_no_spec_infer_or_topk_one(server_args)
+                    and is_fa3_default_architecture(self.model_config.hf_config)
+                ):
+                    server_args.attention_backend = "fa3"
+                else:
+                    server_args.attention_backend = (
+                        "flashinfer" if is_flashinfer_available() else "triton"
+                    )
             else:
-                if is_hopper_with_cuda_12_3():
-                    if server_args.speculative_eagle_topk is None or (
-                        server_args.speculative_eagle_topk is not None
-                        and server_args.speculative_eagle_topk == 1
-                    ):
-                        server_args.attention_backend = "fa3"
-                    else:
-                        server_args.attention_backend = "triton"
+                if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
+                    server_args
+                ):
+                    server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
             logger.info(
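Putting the two branches together, the default backend choice can be read as a small decision function. A hedged, standalone sketch with the predicates reduced to booleans rather than the real server_args/hf_config objects:

def pick_default_attention_backend(
    use_mla_backend: bool,
    hopper_with_cuda_12_3: bool,
    no_spec_infer_or_topk_one: bool,
    fa3_default_architecture: bool,
    flashinfer_available: bool,
) -> str:
    """Mirror of the selection logic above, with predicates pre-evaluated."""
    if not use_mla_backend:
        # Non-MLA models: prefer FA3 on Hopper + CUDA 12.3 when speculative
        # decoding is off (or EAGLE top-k == 1) and the architecture is in the
        # FA3 default list; otherwise fall back to flashinfer/triton.
        if (
            hopper_with_cuda_12_3
            and no_spec_infer_or_topk_one
            and fa3_default_architecture
        ):
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA models: FA3 under the same hardware/speculative conditions, else triton.
    if hopper_with_cuda_12_3 and no_spec_infer_or_topk_one:
        return "fa3"
    return "triton"


# Example: a Llama-class (non-MLA) model on Hopper + CUDA 12.3, no speculative decoding.
print(pick_default_attention_backend(False, True, True, True, True))  # "fa3"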
@@ -569,7 +569,7 @@ def encode_video(video_path, frame_count_limit=None):


 def load_image(
-    image_file: Union[Image.Image, str, bytes]
+    image_file: Union[Image.Image, str, bytes],
 ) -> tuple[Image.Image, tuple[int, int]]:
     image = image_size = None
     if isinstance(image_file, Image.Image):
@@ -1905,3 +1905,28 @@ def get_local_ip_by_remote() -> str:
         return s.getsockname()[0]
     except Exception:
         raise ValueError(f"Can not get local ip")
+
+
+def is_page_size_one(server_args):
+    return server_args.page_size == 1
+
+
+def is_no_spec_infer_or_topk_one(server_args):
+    return server_args.speculative_eagle_topk is None or (
+        server_args.speculative_eagle_topk is not None
+        and server_args.speculative_eagle_topk == 1
+        and is_page_size_one(server_args)
+    )
+
+
+def is_fa3_default_architecture(hf_config):
+    architectures = getattr(hf_config, "architectures", None)
+    if not isinstance(architectures, list) or not architectures:
+        return False
+    default_archs = {
+        "Qwen2ForCausalLM",
+        "Llama4ForConditionalGeneration",
+        "LlamaForCausalLM",
+        "MistralForCausalLM",
+    }
+    return architectures[0] in default_archs
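A short usage sketch for the new helpers; SimpleNamespace stands in for the real ServerArgs and Hugging Face config objects, which carry many more fields:

from types import SimpleNamespace

from sglang.srt.utils import (
    is_fa3_default_architecture,
    is_no_spec_infer_or_topk_one,
    is_page_size_one,
)

# Only the fields read by the helpers are populated.
server_args = SimpleNamespace(page_size=1, speculative_eagle_topk=None)
hf_config = SimpleNamespace(architectures=["LlamaForCausalLM"])

assert is_page_size_one(server_args)
assert is_no_spec_infer_or_topk_one(server_args)  # no speculative decoding configured
assert is_fa3_default_architecture(hf_config)

# EAGLE top-k > 1 (or top-k == 1 with page_size != 1) disqualifies the FA3 default.
server_args.speculative_eagle_topk = 2
assert not is_no_spec_infer_or_topk_one(server_args)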