Commit bd363067 authored by lizhigong's avatar lizhigong
Browse files

Merge branch 'v0.8.5.post1-dev' into v0.8.5-zero_overhead

parents 87ef4618 d36deb1a
...@@ -27,7 +27,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: ...@@ -27,7 +27,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
def test_deepseek_mla_attn_backend_module(): def test_deepseek_mla_attn_backend_module():
model_runner = _create_model_runner( model_runner = _create_model_runner(
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", os.path.join(models_path_prefix, "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"),
trust_remote_code=True, trust_remote_code=True,
enable_chunked_prefill=False, enable_chunked_prefill=False,
) )
......
...@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache( ...@@ -1979,3 +1979,31 @@ def flash_mla_with_kvcache(
# torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache, # torch.ops._C.cutlass_mla_decode(out, q_nope, q_pe, kv_c_and_k_pe_cache,
# seq_lens, page_table, scale) # seq_lens, page_table, scale)
# return out # return out
def moe_fused_gate(
    input_tensor,
    bias,
    num_expert_group,
    topk_group,
    topk,
    n_share_experts_fusion=0,
    routed_scaling_factor=0,
):
    """Select the top-k experts via the fused hierarchical gating kernel.

    The kernel picks experts in a two-level fashion: the experts are split
    into ``num_expert_group`` groups, the sum of the top-2 expert weights in
    each group is used as that group's weight to select expert groups, and
    the ``topk`` experts are then selected inside the chosen groups.

    The number of experts is taken from the shape of ``input_tensor``.
    Only a power-of-two number of experts is currently supported, it must be
    divisible by ``num_expert_group``, and
    ``#experts / num_expert_group <= 32`` is required for now. For
    unsupported cases, use the ``biased_grouped_topk`` function in
    ``sglang.srt.layers.moe.topk`` instead.

    Args:
        input_tensor: Gating input whose shape determines the expert count.
        bias: Bias forwarded unchanged to the kernel.
        num_expert_group: Number of groups the experts are split into.
        topk_group: Forwarded unchanged to the kernel (presumably the number
            of expert groups to keep -- confirm against the kernel).
        topk: Number of experts to select.
        n_share_experts_fusion: If > 0, the last expert is replaced with a
            round-robin shared expert.
        routed_scaling_factor: If > 0, the last expert is scaled by this
            factor.

    Returns:
        Whatever ``torch.ops._moe_C.moe_fused_gate`` returns.
    """
    fused_gate_op = torch.ops._moe_C.moe_fused_gate
    # Arguments are passed positionally, in the same order as the signature.
    gate_args = (
        input_tensor,
        bias,
        num_expert_group,
        topk_group,
        topk,
        n_share_experts_fusion,
        routed_scaling_factor,
    )
    return fused_gate_op(*gate_args)
...@@ -27,7 +27,12 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8, ...@@ -27,7 +27,12 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.platforms import current_platform
if not current_platform.is_rocm():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)
else:
from flash_attn import (flash_attn_varlen_func, vllm_flash_attn_varlen_func,
flash_attn_with_kvcache) flash_attn_with_kvcache)
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -807,6 +812,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -807,6 +812,7 @@ class FlashAttentionImpl(AttentionImpl):
(num_kv_tokens, num_kv_heads, head_size)) (num_kv_tokens, num_kv_heads, head_size))
descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1]) descale_shape = (q_seq_start_loc.shape[0] - 1, key.shape[1])
if not current_platform.is_rocm():
flash_attn_varlen_func( flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
...@@ -826,6 +832,21 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -826,6 +832,21 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
) )
else:
prefill_output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=q_seq_start_loc,
cu_seqlens_k=k_seq_start_loc,
max_seqlen_q=q_seq_len,
max_seqlen_k=k_seq_len,
softmax_scale=softmax_scale,
causal=_get_causal_option(attn_type),
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
else: else:
# prefix-enabled attention # prefix-enabled attention
assert attn_type == AttentionType.DECODER, ( assert attn_type == AttentionType.DECODER, (
...@@ -835,6 +856,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -835,6 +856,7 @@ class FlashAttentionImpl(AttentionImpl):
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, descale_shape = (prefill_meta.query_start_loc.shape[0] - 1,
key.shape[1]) key.shape[1])
if not current_platform.is_rocm():
flash_attn_varlen_func( # noqa flash_attn_varlen_func( # noqa
q=query, q=query,
k=key_cache, k=key_cache,
...@@ -855,6 +877,27 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -855,6 +877,27 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
) )
else:
vllm_flash_attn_varlen_func( # noqa
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
out=prefill_output,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
...@@ -870,6 +913,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -870,6 +913,7 @@ class FlashAttentionImpl(AttentionImpl):
assert decode_meta.query_start_loc is not None assert decode_meta.query_start_loc is not None
descale_shape = (decode_meta.query_start_loc.shape[0] - 1, descale_shape = (decode_meta.query_start_loc.shape[0] - 1,
key.shape[1]) key.shape[1])
if not current_platform.is_rocm():
flash_attn_varlen_func( flash_attn_varlen_func(
q=decode_query, q=decode_query,
k=key_cache, k=key_cache,
...@@ -890,6 +934,22 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -890,6 +934,22 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
) )
else:
decode_output = flash_attn_varlen_func(
q=decode_query,
k=key_cache,
v=value_cache,
cu_seqlens_q=decode_meta.query_start_loc,
max_seqlen_q=decode_meta.max_decode_query_len,
seqused_k=decode_meta.seq_lens_tensor,
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
)
else: else:
# Use flash_attn_with_kvcache for normal decoding. # Use flash_attn_with_kvcache for normal decoding.
( (
...@@ -898,6 +958,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -898,6 +958,7 @@ class FlashAttentionImpl(AttentionImpl):
block_tables_arg, block_tables_arg,
) = get_seq_len_block_table_args(decode_meta, False, attn_type) ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2]) descale_shape = (seq_lens_arg.shape[0], key_cache.shape[-2])
if not current_platform.is_rocm():
flash_attn_with_kvcache( flash_attn_with_kvcache(
q=decode_query.unsqueeze(1), q=decode_query.unsqueeze(1),
k_cache=key_cache, k_cache=key_cache,
...@@ -915,6 +976,19 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -915,6 +976,19 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
) )
else:
decode_output = flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
block_table=block_tables_arg,
cache_seqlens=seq_lens_arg,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
)
return output return output
......
...@@ -22,7 +22,7 @@ from vllm.logger import init_logger ...@@ -22,7 +22,7 @@ from vllm.logger import init_logger
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.rocm import use_rocm_custom_paged_attention from vllm.platforms.rocm import use_rocm_custom_paged_attention
from vllm.utils import SUPPORT_TC from vllm.utils import SUPPORT_TC, gpuname
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
...@@ -578,7 +578,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -578,7 +578,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.fa_attn_func = flash_attn_varlen_func self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN: if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN and gpuname.startswith('K100_AI'):
from flash_attn import vllm_flash_attn_varlen_func from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
...@@ -852,9 +852,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -852,9 +852,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
# prefix-enabled attention - # prefix-enabled attention -
# not applicable for encoder-only models # not applicable for encoder-only models
# if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN: if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN or gpuname.startswith('BW'):
# self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
version_key = triton_key() version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY: if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix( output[:num_prefill_tokens] = paged_attn.forward_prefix(
...@@ -922,10 +920,9 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -922,10 +920,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs, num_heads, head_size = decode_query.shape num_seqs, num_heads, head_size = decode_query.shape
block_size = value_cache.shape[3] block_size = value_cache.shape[3]
gqa_ratio = num_heads // self.num_kv_heads gqa_ratio = num_heads // self.num_kv_heads
# use_custom = use_rocm_custom_paged_attention( use_custom = use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio, decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len, self.sliding_window) decode_meta.max_decode_seq_len, self.sliding_window)
use_custom = False
if use_custom: if use_custom:
max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type
!= AttentionType.ENCODER_DECODER else != AttentionType.ENCODER_DECODER else
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# yapf: disable # yapf: disable
import os
import argparse import argparse
import dataclasses import dataclasses
import json import json
...@@ -35,12 +36,12 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 ...@@ -35,12 +36,12 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor from vllm.utils import FlexibleArgumentParser, GiB_bytes, is_in_ray_actor
# yapf: enable # yapf: enable
logger = init_logger(__name__) logger = init_logger(__name__)
ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
models_path_prefix = os.getenv('VLLM_OPTEST_MODELS_PATH') or os.getenv("OPTEST_MODELS_PATH")
# object is used to allow for special typing forms # object is used to allow for special typing forms
T = TypeVar("T") T = TypeVar("T")
...@@ -203,7 +204,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: ...@@ -203,7 +204,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
@dataclass @dataclass
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str = 'facebook/opt-125m' model: str = os.path.join(models_path_prefix, 'facebook/opt-125m') if models_path_prefix is not None else 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None tokenizer: Optional[str] = None
hf_config_path: Optional[str] = None hf_config_path: Optional[str] = None
......
...@@ -2062,8 +2062,14 @@ class LLMEngine: ...@@ -2062,8 +2062,14 @@ class LLMEngine:
prompt_type: Literal["encoder", "decoder"], prompt_type: Literal["encoder", "decoder"],
): ):
model_config = self.model_config model_config = self.model_config
tokenizer = (None if self.tokenizer is None else if self.tokenizer is None:
self.tokenizer.get_lora_tokenizer(lora_request)) tokenizer = None
elif self.model_config.tokenizer_mode != "cpm":
tokenizer = self.tokenizer.get_lora_tokenizer(lora_request)
else:
tokenizer = self.tokenizer
# tokenizer = (None if self.tokenizer is None else
# self.tokenizer.get_lora_tokenizer(lora_request))
prompt_ids = prompt_inputs["prompt_token_ids"] prompt_ids = prompt_inputs["prompt_token_ids"]
if not prompt_ids: if not prompt_ids:
......
...@@ -124,11 +124,12 @@ if TYPE_CHECKING: ...@@ -124,11 +124,12 @@ if TYPE_CHECKING:
VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False VLLM_PCIE_USE_CUSTOM_ALLREDUCE: bool = False
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS:int = 0 VLLM_TBO_REQ_DELAY_MS:int = 0
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -801,6 +802,19 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -801,6 +802,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENFORCE_EAGER_BS_THRESHOLD": "VLLM_ENFORCE_EAGER_BS_THRESHOLD":
lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")), lambda: int(os.environ.get("VLLM_ENFORCE_EAGER_BS_THRESHOLD", "-1")),
# 'has_context' is a variable in common.py that is computed from the
# attention metadata by default. Computing it may introduce a
# synchronization and hurt performance, so it is assigned False directly.
# If this causes any problems, use this environment variable
# to restore the default (computed) behavior.
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM will use FlashAttention Backend for attention computation on rocm
"VLLM_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in
("true", "1")),
# Enable two batch overlap. # Enable two batch overlap.
"VLLM_ENABLE_TBO": "VLLM_ENABLE_TBO":
lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))), lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))),
...@@ -813,13 +827,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -813,13 +827,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ZERO_OVERHEAD": "VLLM_ZERO_OVERHEAD":
lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))), lambda: bool(int(os.getenv("VLLM_ZERO_OVERHEAD", "0"))),
# 'has_comtext' is a variable in common.py, which is calculated # If set, vLLM will enable the moe_fused_gate kernel.
# by metadata by default. However, it may introduce synchronization "VLLM_ENABLE_MOE_FUSED_GATE":
# and affect performance, so it is directly assigned as False. lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
# If there are any problems during use, use environment variables
# to restore the default usage.
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
} }
# end-env-vars-definition # end-env-vars-definition
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 3,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
}
}
...@@ -3,97 +3,97 @@ ...@@ -3,97 +3,97 @@
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 64,
"num_warps": 4, "num_warps": 2,
"num_stages": 4, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"2": { "2": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 16,
"num_warps": 2, "num_warps": 2,
"num_stages": 4, "num_stages": 3,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"4": { "4": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 32,
"num_warps": 2, "num_warps": 2,
"num_stages": 3, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"8": { "8": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 32,
"num_warps": 4, "num_warps": 2,
"num_stages": 4, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"16": { "16": {
"BLOCK_SIZE_M": 16, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 64,
"num_warps": 4, "num_warps": 2,
"num_stages": 4, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"24": { "24": {
"BLOCK_SIZE_M": 32, "BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 16,
"num_warps": 4, "num_warps": 2,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"32": { "32": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 32,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"48": { "48": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 16,
"num_warps": 8, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"64": { "64": {
"BLOCK_SIZE_M": 64, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 32,
"num_warps": 4, "num_warps": 8,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"96": { "96": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 16,
"num_warps": 4, "num_warps": 8,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"128": { "128": {
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 16,
"num_warps": 4, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
...@@ -102,14 +102,14 @@ ...@@ -102,14 +102,14 @@
"BLOCK_SIZE_M": 128, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 16,
"num_warps": 4, "num_warps": 8,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"512": { "512": {
"BLOCK_SIZE_M": 256, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32, "BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 4, "num_warps": 4,
...@@ -117,17 +117,17 @@ ...@@ -117,17 +117,17 @@
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"1024": { "1024": {
"BLOCK_SIZE_M": 256, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"1536": { "1536": {
"BLOCK_SIZE_M": 256, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 8,
...@@ -135,8 +135,8 @@ ...@@ -135,8 +135,8 @@
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"2048": { "2048": {
"BLOCK_SIZE_M": 256, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 8,
...@@ -144,20 +144,20 @@ ...@@ -144,20 +144,20 @@
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"3072": { "3072": {
"BLOCK_SIZE_M": 256, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
}, },
"4096": { "4096": {
"BLOCK_SIZE_M": 256, "BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64, "BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
"num_warps": 8, "num_warps": 4,
"num_stages": 2, "num_stages": 2,
"num_ldmatrixes": 1 "num_ldmatrixes": 1
} }
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 2,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 3,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import functools import functools
import json import json
import os import os
import math
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
...@@ -1182,6 +1183,10 @@ def fused_topk( ...@@ -1182,6 +1183,10 @@ def fused_topk(
return topk_weights, topk_ids return topk_weights, topk_ids
def is_power_of_two(n):
    """Return True when ``n`` is a positive power of two.

    Integers are tested exactly with bit arithmetic. The floating-point
    ``math.log2`` check is inexact for large values: ``2**60 + 1`` rounds
    to a log of exactly ``60.0`` and would be reported as a power of two.
    Non-``int`` inputs (e.g. floats) keep the floating-point check.

    Args:
        n: The number to test.

    Returns:
        bool: True if ``n`` is greater than zero and a power of two.
    """
    if isinstance(n, int):
        # A positive power of two has exactly one bit set, so clearing the
        # lowest set bit with n & (n - 1) leaves zero.
        return n > 0 and (n & (n - 1)) == 0
    # Fallback for non-int numerics: same check as before.
    return n > 0 and math.log2(n).is_integer()
# This is used by the Deepseek-V2 and Deepseek-V3 model # This is used by the Deepseek-V2 and Deepseek-V3 model
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def grouped_topk( def grouped_topk(
......
...@@ -19,10 +19,13 @@ from vllm.logger import init_logger ...@@ -19,10 +19,13 @@ from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk, is_power_of_two)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm import _custom_ops as ops
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from .fused_moe import fused_experts from .fused_moe import fused_experts
...@@ -174,6 +177,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -174,6 +177,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward( return self.forward(
x=x, x=x,
...@@ -191,7 +196,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -191,7 +196,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
def forward_cuda( def forward_cuda(
self, self,
...@@ -211,6 +218,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -211,6 +218,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -222,7 +231,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -222,7 +231,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
return fused_experts( return fused_experts(
hidden_states=x, hidden_states=x,
...@@ -255,6 +266,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -255,6 +266,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**kwargs, **kwargs,
): ):
assert activation == "silu", f"{activation} is not supported." assert activation == "silu", f"{activation} is not supported."
...@@ -290,6 +303,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -290,6 +303,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
...@@ -436,6 +451,7 @@ class FusedMoE(torch.nn.Module): ...@@ -436,6 +451,7 @@ class FusedMoE(torch.nn.Module):
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
): ):
super().__init__() super().__init__()
...@@ -505,6 +521,7 @@ class FusedMoE(torch.nn.Module): ...@@ -505,6 +521,7 @@ class FusedMoE(torch.nn.Module):
self.scoring_func = scoring_func self.scoring_func = scoring_func
self.e_score_correction_bias = e_score_correction_bias self.e_score_correction_bias = e_score_correction_bias
self.activation = activation self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
if self.scoring_func != "softmax" and not self.use_grouped_topk: if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for " raise ValueError("Only softmax scoring function is supported for "
...@@ -554,9 +571,16 @@ class FusedMoE(torch.nn.Module): ...@@ -554,9 +571,16 @@ class FusedMoE(torch.nn.Module):
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce self.tbo_all_reduce = tbo_all_reduce
# The moe_fused_gate kernel currently requires num_experts / num_expert_group <= MAX_VPT (32), and a power-of-two number of experts. Once the kernel can handle MAX_VPT > 32, the `<= 32` condition below can be removed.
self.use_fused_gate = envs.VLLM_ENABLE_MOE_FUSED_GATE \
and self.e_score_correction_bias is not None \
and self.global_num_experts // num_expert_group <= 32 \
and is_power_of_two(e_score_correction_bias.shape[0])
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_weight: torch.Tensor,
...@@ -839,14 +863,24 @@ class FusedMoE(torch.nn.Module): ...@@ -839,14 +863,24 @@ class FusedMoE(torch.nn.Module):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None): e_score_correction_bias: Optional[torch.Tensor] = None,
from vllm.model_executor.layers.fused_moe.fused_moe import ( routed_scaling_factor: Optional[float] = None,
fused_topk, grouped_topk) use_fused_gate: Optional[bool] = False):
# DeepSeekV2 uses grouped_top_k # DeepSeekV2 uses grouped_top_k
if use_grouped_topk: if use_grouped_topk:
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
if use_fused_gate:
topk_weights, topk_ids = ops.moe_fused_gate(
router_logits,
e_score_correction_bias,
num_expert_group,
topk_group,
top_k,
routed_scaling_factor=routed_scaling_factor,
n_share_experts_fusion=0,
)
else:
topk_weights, topk_ids = grouped_topk( topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
...@@ -927,6 +961,8 @@ class FusedMoE(torch.nn.Module): ...@@ -927,6 +961,8 @@ class FusedMoE(torch.nn.Module):
activation=self.activation, activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input, apply_router_weight_on_input=self.apply_router_weight_on_input,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate
) )
if self.dp_size > 1: if self.dp_size > 1:
......
...@@ -384,6 +384,8 @@ class BlockInt8MoEMethod: ...@@ -384,6 +384,8 @@ class BlockInt8MoEMethod:
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -399,7 +401,9 @@ class BlockInt8MoEMethod: ...@@ -399,7 +401,9 @@ class BlockInt8MoEMethod:
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
) )
# Expert fusion with INT8 quantization # Expert fusion with INT8 quantization
......
...@@ -31,7 +31,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ...@@ -31,7 +31,10 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
should_ignore_layer) should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
logger = init_logger(__name__) logger = init_logger(__name__)
__all__ = ["CompressedTensorsLinearMethod"] __all__ = ["CompressedTensorsLinearMethod"]
...@@ -540,8 +543,30 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -540,8 +543,30 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
    """Store the quantization config and set up W8A8 kernel selection.

    Args:
        quantization_config: The compressed-tensors configuration this
            linear method was created for.
    """
    self.quantization_config = quantization_config
    # Helper object used later to look up tuned Triton configs.
    self.tritonsingleton = W8a8GetCacheJSON()
    # W8A8_SUPPORT_METHODS selects the W8A8 strategy; it defaults to 1
    # when the environment variable is unset.
    strategy_env = os.getenv('W8A8_SUPPORT_METHODS', '1')
    self.w8a8_strategy = int(strategy_env)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Prepare ``layer`` for W8A8 execution once its weights are loaded.

    When ``self.w8a8_strategy == 1`` the weight is left untouched. The first
    time a given ``(n, k)`` weight shape is seen, its tuned configs are
    looked up through ``self.tritonsingleton``, merged into the shared
    ``triton_json_dict``, and every config found is passed once to
    ``ops.triton_int8_gemm_helper``.

    For any other strategy, the weight data is replaced by its transposed,
    contiguous copy reshaped to ``(n, -1)``.

    In both cases the layer's scheme is then asked to finish its own
    post-loading processing.

    Args:
        layer: Module whose ``weight`` is 2-D with shape ``(n, k)`` and whose
            ``scheme`` provides ``process_weights_after_loading``.
    """
    n, k = layer.weight.shape[0], layer.weight.shape[1]

    if self.w8a8_strategy == 1:
        cache = self.tritonsingleton
        # Use an ordered tuple, not a set: ``{n, k} == {k, n}``, so a set key
        # would make an (n, k) layer hide a later (k, n) layer and skip
        # loading its configs. A set also collapses to one element if n == k.
        shape_key = (n, k)
        if shape_key not in cache.weight_shapes:
            cache.weight_shapes.append(shape_key)
            json_file = cache.get_w8a8json_name(n, k)
            configs_dict = cache.get_triton_cache(json_file, n, k)
            if configs_dict:
                cache.triton_json_dict.update(configs_dict)
                for key, value in configs_dict.items():
                    # The leading "_"-separated field of the key is M.
                    m = int(key.split('_')[0])
                    ops.triton_int8_gemm_helper(
                        m=m,
                        n=n,
                        k=k,
                        per_token_act_quant=True,
                        per_out_channel_weight_quant=True,
                        use_bias=False,
                        best_config=value)
    else:
        # Non-Triton strategies expect the re-laid-out weight.
        weight_data = layer.weight.data
        layer.weight.data = weight_data.T.contiguous().reshape(n, -1)

    layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
......
...@@ -296,6 +296,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -296,6 +296,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported." assert activation == "silu", "Only SiLU activation is supported."
...@@ -309,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase): ...@@ -309,7 +311,9 @@ class MoeWNA16Method(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate)
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
......
...@@ -300,8 +300,8 @@ def _w8a8_block_int8_matmul( ...@@ -300,8 +300,8 @@ def _w8a8_block_int8_matmul(
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
): ):
"""Triton-accelerated function used to perform linear operations (dot """Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and product) on input tensors `A` and `B` with block-wise quantization,
store the result in output tensor `C`. and store the result in output tensor `C`.
""" """
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
...@@ -316,16 +316,29 @@ def _w8a8_block_int8_matmul( ...@@ -316,16 +316,29 @@ def _w8a8_block_int8_matmul(
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# a_ptrs = A + (offs_am[:, None] * stride_am)
# b_ptrs = B + (offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n # offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
a = tl.load(a_ptrs, a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0) other=0.0)
...@@ -333,16 +346,13 @@ def _w8a8_block_int8_matmul( ...@@ -333,16 +346,13 @@ def _w8a8_block_int8_matmul(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0) other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16: if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16) c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16: elif C.dtype.element_ty == tl.float16:
...@@ -436,29 +446,10 @@ def w8a8_block_int8_matmul( ...@@ -436,29 +446,10 @@ def w8a8_block_int8_matmul(
C_shape = A.shape[:-1] + (N, ) C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
# configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1]) if len(W8A8_TRITONJSON.triton_json_list)==0:
# if configs:
# # If an optimal configuration map has been found, look up the
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
# else:
# # Default config
# # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
# config = {
# "BLOCK_SIZE_M": 64,
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
#print("W8A8_TRITONJSON.triton_json_dict[0]:",W8A8_TRITONJSON.triton_json_dict[0])
if len(W8A8_TRITONJSON.triton_json_dict)==0:
config=None config=None
#print("len(W8A8_TRITONJSON.triton_json_dict)=0:",len(W8A8_TRITONJSON.triton_json_dict))
elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_dict[0]: elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_list[0]:
if M<=16: if M<=16:
m_=M m_=M
elif M<=64: elif M<=64:
...@@ -480,12 +471,13 @@ def w8a8_block_int8_matmul( ...@@ -480,12 +471,13 @@ def w8a8_block_int8_matmul(
m_=4096 m_=4096
else: else:
m_=8192 m_=8192
#print("==================m:{},n:{},k:{}".format(M,N,K))
config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"] config=W8A8_TRITONJSON.triton_json_list[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"]
else: else:
config=None config=None
if config==None:
# print("m:{},n:{},k:{}".format(M,N,K)) # print("m:{},n:{},k:{}".format(M,N,K))
# print("config not found!") # print("config not found!")
......
...@@ -420,7 +420,7 @@ def apply_int8_linear( ...@@ -420,7 +420,7 @@ def apply_int8_linear(
if len(W8A8_TRITONJSON.triton_json_dict)==0: if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict[0]: elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16: if m<=16:
m_=m m_=m
#best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"] #best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m}_{n}_{k}"]
...@@ -444,7 +444,7 @@ def apply_int8_linear( ...@@ -444,7 +444,7 @@ def apply_int8_linear(
else: else:
m_=8192 m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[0][f"{m_}_{n}_{k}"] best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else: else:
best_config=None best_config=None
......
...@@ -252,6 +252,8 @@ class W8A8Int8MoEMethod: ...@@ -252,6 +252,8 @@ class W8A8Int8MoEMethod:
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -266,7 +268,9 @@ class W8A8Int8MoEMethod: ...@@ -266,7 +268,9 @@ class W8A8Int8MoEMethod:
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
) )
return fused_experts( return fused_experts(
......
...@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -142,7 +142,8 @@ class DeepseekV2MoE(nn.Module):
topk_group=config.topk_group, topk_group=config.topk_group,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
scoring_func=config.scoring_func, scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias,) e_score_correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
...@@ -961,7 +962,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -961,7 +962,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
if configs_dict: if configs_dict:
all_json.update(configs_dict) all_json.update(configs_dict)
self.tritonsingleton.triton_json_dict.append(all_json) self.tritonsingleton.triton_json_list.append(all_json)
#print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0])) #print("self.tritonsingleton.triton_json_dict[0].shape:",len(self.tritonsingleton.triton_json_dict[0]))
for key, value in all_json.items(): for key, value in all_json.items():
m=int(key.split('_')[0]) m=int(key.split('_')[0])
......
...@@ -46,7 +46,6 @@ from .interfaces import SupportsPP ...@@ -46,7 +46,6 @@ from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter, from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm.utils import is_hip,W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
...@@ -219,12 +218,10 @@ class GPTNeoXModel(nn.Module): ...@@ -219,12 +218,10 @@ class GPTNeoXModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_in(input_ids) return self.embed_in(input_ids)
...@@ -288,52 +285,6 @@ class GPTNeoXModel(nn.Module): ...@@ -288,52 +285,6 @@ class GPTNeoXModel(nn.Module):
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
#当为triton支持推理的时候不能进行处理
if self.quant_method == "compressed_tensors":
os.environ['LM_NN'] = '0'
lay_key_words = [
"attention.query_key_value.weight",
"attention.dense.weight",
"mlp.dense_h_to_4h.weight",
"mlp.dense_4h_to_h.weight",
]
combined_words = "|".join(lay_key_words)
weight_shapes=[]
all_json={}
matched_key_words=set()
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches and "scale" not in layername:
weight_data =params_dict[layername]
n=weight_data.shape[0]
k=weight_data.shape[1]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if self.w8a8_strategy!=1:
_weight=weight_data.T.contiguous().reshape(n,-1)
weight_data.data.copy_(_weight)
#下面是针对模型记录模型出现k和n值
elif len(matched_key_words) < 4 and matches[0] not in matched_key_words:
matched_key_words.add(matches[0])
weight_shapes.append({n,k})
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
all_json.update(configs_dict)
if self.w8a8_strategy==1:
self.tritonsingleton.triton_json_dict.append(all_json)
#找到的所有config都进行一次warmup
for key, value in all_json.items():
m=int(key.split('_')[0])
n=int(key.split('_')[1])
k=int(key.split('_')[2])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
return loaded_params return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment