Commit 1f4b9553 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev-ds' into v0.9.2-dev-ds

parents 5a5e4f3b c2e6f453
...@@ -114,9 +114,9 @@ def run_moe_test( ...@@ -114,9 +114,9 @@ def run_moe_test(
return baseline_output return baseline_output
@pytest.mark.parametrize("m", [1, 33, 64, 222, 32768, 40000]) @pytest.mark.parametrize("m", [1, 33, 64, 32768, 40000])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS) @pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("ep_size", EP_SIZE)
......
...@@ -167,6 +167,7 @@ if TYPE_CHECKING: ...@@ -167,6 +167,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
...@@ -1112,6 +1113,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1112,6 +1113,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_MOE_SUM": "VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD", "True").lower() in
("true", "1")),
# vLLM will use lightop moe_sum # vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM": "VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "True").lower() in
......
...@@ -1896,7 +1896,7 @@ def fused_experts_impl( ...@@ -1896,7 +1896,7 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
if envs.VLLM_USE_LIGHTOP: if envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
from lightop import op as op from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()), op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx], output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=shared_output[begin_chunk_idx:end_chunk_idx],
......
...@@ -247,6 +247,8 @@ def get_model_architecture( ...@@ -247,6 +247,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
...@@ -258,6 +260,8 @@ def get_model_architecture( ...@@ -258,6 +260,8 @@ def get_model_architecture(
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]: if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"): if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1' os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD"):
os.environ['VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"): if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1' os.environ['VLLM_USE_OPT_CAT'] = '1'
# awq相关配置 # awq相关配置
......
...@@ -225,7 +225,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -225,7 +225,7 @@ class DeepseekV2MoE(nn.Module):
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
if not self.use_mori_ep: if not self.use_mori_ep:
if envs.VLLM_USE_LIGHTOP: if envs.envs.VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
......
...@@ -12,6 +12,7 @@ from vllm.attention.selector import (backend_name_to_enum, ...@@ -12,6 +12,7 @@ from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend) get_global_forced_attn_backend)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform from vllm.platforms import _Backend, current_platform
from vllm.utils import SUPPORT_TC
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -82,6 +83,8 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: ...@@ -82,6 +83,8 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
selected_backend = backend_name_to_enum(backend_by_env_var) selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None: if selected_backend is None:
if current_platform.is_cuda() or current_platform.is_rocm(): if current_platform.is_cuda() or current_platform.is_rocm():
if not SUPPORT_TC:
selected_backend = _Backend.TORCH_SDPA
device_available = current_platform.has_device_capability(80) device_available = current_platform.has_device_capability(80)
if device_available and support_fa: if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available from transformers.utils import is_flash_attn_2_available
......
...@@ -16,13 +16,15 @@ from vllm.utils import cuda_device_count_stateless ...@@ -16,13 +16,15 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import SUPPORT_TC from vllm.utils import is_kme, SUPPORT_TC
if not SUPPORT_TC: if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0' os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0' os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
os.environ['VLLM_USE_FLASH_MLA'] = '0' os.environ['VLLM_USE_FLASH_MLA'] = '0'
if is_kme:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
...@@ -296,6 +298,8 @@ class RocmPlatform(Platform): ...@@ -296,6 +298,8 @@ class RocmPlatform(Platform):
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
logger.info("%s is not supported in AMD GPUs.", selected_backend) logger.info("%s is not supported in AMD GPUs.", selected_backend)
if is_kme:
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
logger.info("Using ROCmFlashAttention backend.") logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
......
...@@ -85,6 +85,7 @@ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 ...@@ -85,6 +85,7 @@ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
is_kme = gpuname.startswith('K100_AI') or gpuname.startswith('K500SM_AI')
SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('K500SM_AI') or gpuname.startswith('BW') SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('K500SM_AI') or gpuname.startswith('BW')
def _generate_random_int8( def _generate_random_int8(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment