Commit 0e607f8e authored by zhuwenwen's avatar zhuwenwen
Browse files

fix tests of kernels

set VLLM_USE_PD_SPLIT=1
update moe_align_block_size
parent cbdc58ec
...@@ -401,7 +401,7 @@ _EMBEDDING_EXAMPLE_MODELS = { ...@@ -401,7 +401,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
"LlavaNextForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "royokong/e5-v")), "LlavaNextForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "royokong/e5-v")),
"Phi3VForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "TIGER-Lab/VLM2Vec-Full"), "Phi3VForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "TIGER-Lab/VLM2Vec-Full"),
trust_remote_code=True), trust_remote_code=True),
"Qwen2VLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo(os.path.join(models_path_prefix, "MrLight/dse-qwen2-2b-mrl-v1")), # noqa: E501
"PrithviGeoSpatialMAE": _HfExamplesInfo(os.path.join(models_path_prefix, "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"), # noqa: E501 "PrithviGeoSpatialMAE": _HfExamplesInfo(os.path.join(models_path_prefix, "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"), # noqa: E501
dtype=torch.float16, dtype=torch.float16,
enforce_eager=True, enforce_eager=True,
...@@ -656,9 +656,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { ...@@ -656,9 +656,9 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
os.path.join(models_path_prefix, "meituan-longcat/LongCat-Flash-Chat"), os.path.join(models_path_prefix, "meituan-longcat/LongCat-Flash-Chat"),
trust_remote_code=True, trust_remote_code=True,
speculative_model=os.path.join(models_path_prefix, "meituan-longcat/LongCat-Flash-Chat")), speculative_model=os.path.join(models_path_prefix, "meituan-longcat/LongCat-Flash-Chat")),
"MiMoMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL")), "MiMoMTPModel": _HfExamplesInfo(os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"),
trust_remote_code=True, trust_remote_code=True,
speculative_model=os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"), speculative_model=os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL")),
"Qwen3NextMTP": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-Next-80B-A3B-Instruct"), "Qwen3NextMTP": _HfExamplesInfo(os.path.join(models_path_prefix, "Qwen/Qwen3-Next-80B-A3B-Instruct"),
min_transformers_version="4.56.3"), min_transformers_version="4.56.3"),
} }
......
...@@ -233,6 +233,8 @@ if TYPE_CHECKING: ...@@ -233,6 +233,8 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
...@@ -1635,9 +1637,19 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1635,9 +1637,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")), ("true", "1")),
# vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
# vllm will use silu_mul_quant fused op,
# This variable has a default value of true,
# but it is still controlled by CRQ and RQ.
"USE_FUSED_SILU_MUL_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_SILU_MUL_QUANT", "0"))),
# vLLM will split prefill and decode, not mix up # vLLM will split prefill and decode, not mix up
"VLLM_USE_PD_SPLIT": "VLLM_USE_PD_SPLIT":
lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "False").lower() in lambda: (os.environ.get("VLLM_USE_PD_SPLIT", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will sync to avoid pp vmfault # vLLM will sync to avoid pp vmfault
"VLLM_USE_PP_SYNC": "VLLM_USE_PP_SYNC":
......
...@@ -102,7 +102,7 @@ def moe_align_block_size( ...@@ -102,7 +102,7 @@ def moe_align_block_size(
expert_map = expert_map, expert_map = expert_map,
expert_mask = expert_mask, expert_mask = expert_mask,
num_local_tokens = None, num_local_tokens = None,
Is_fuse_fill = False) Is_fuse_fill = True)
else: else:
if envs.VLLM_USE_LIGHTOP_MOE_ALIGN: if envs.VLLM_USE_LIGHTOP_MOE_ALIGN:
from lightop import op as op from lightop import op as op
...@@ -111,7 +111,7 @@ def moe_align_block_size( ...@@ -111,7 +111,7 @@ def moe_align_block_size(
expert_map = None, expert_map = None,
expert_mask = None, expert_mask = None,
num_local_tokens = None, num_local_tokens = None,
Is_fuse_fill = False) Is_fuse_fill = True)
else: else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad) expert_ids, num_tokens_post_pad)
......
...@@ -137,15 +137,9 @@ def get_rope( ...@@ -137,15 +137,9 @@ def get_rope(
scaling_alpha, dtype) scaling_alpha, dtype)
elif "factor" in rope_scaling: elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
scaling_alpha = rope_scaling["alpha"] rotary_emb = DynamicNTKScalingRotaryEmbedding(
if scaling_alpha: head_size, rotary_dim, max_position, base, is_neox_style,
rotary_emb = DynamicNTKAlphaRotaryEmbedding( scaling_factor, dtype)
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_alpha, dtype)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor, dtype)
else: else:
raise ValueError("Dynamic rope scaling must contain either " raise ValueError("Dynamic rope scaling must contain either "
"'alpha' or 'factor' field") "'alpha' or 'factor' field")
......
...@@ -199,11 +199,11 @@ def _get_model_architecture( ...@@ -199,11 +199,11 @@ def _get_model_architecture(
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): # if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' # os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]: if architectures in [['Qwen3MoeForCausalLM']]:
# if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"): if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
...@@ -226,11 +226,11 @@ def _get_model_architecture( ...@@ -226,11 +226,11 @@ def _get_model_architecture(
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): # if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' # os.environ['VLLM_USE_PD_SPLIT'] = '1'
if architectures in [['Qwen3MoeForCausalLM']]: if architectures in [['Qwen3MoeForCausalLM']]:
# if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_ALIGN"):
# os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"): if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1' os.environ['VLLM_USE_LIGHTOP_MOE_SUM'] = '1'
if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"): if not envs.is_set("VLLM_USE_FUSE_SILU_AND_MUL"):
......
...@@ -129,8 +129,8 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -129,8 +129,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8": torch.uint8, "fp8": torch.uint8,
# "fp8_e4m3": torch.uint8, "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
"int8": torch.int8, "int8": torch.int8,
"fp8_inc": torch.float8_e4m3fn, "fp8_inc": torch.float8_e4m3fn,
"fp8_ds_mla": torch.uint8, "fp8_ds_mla": torch.uint8,
......
...@@ -1089,14 +1089,15 @@ class Scheduler(SchedulerInterface): ...@@ -1089,14 +1089,15 @@ class Scheduler(SchedulerInterface):
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
if envs.VLLM_USE_PD_SPLIT: if envs.VLLM_USE_PD_SPLIT:
return self.schedule_split_pd() if self.use_mla:
else: if self.full_cuda_graph and self.num_spec_tokens > 0:
if self.connector is not None: return self.schedule_split_pd()
return self.schedule_default() else:
if self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0 : self.schedule_default()
return self.schedule_split_pd()
else: else:
return self.schedule_default() return self.schedule_split_pd()
else:
return self.schedule_default()
def _update_after_schedule( def _update_after_schedule(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment