Commit ad141d07 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' into v0.9.2-dev-ds

parents 36a7e89e f7e9c329
...@@ -965,7 +965,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ...@@ -965,7 +965,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3); vllm::Fp8KVCacheDataType::kFp8E4M3);
} }
} else if (kv_cache_dtype == "fp8_e5m2") { } else if (kv_cache_dtype == "fp8_e5m2") {
if (src_cache.dtype() == at::ScalarType::Float) { if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2); CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::Half) { } else if (src_cache.dtype() == at::ScalarType::Half) {
...@@ -980,7 +980,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, ...@@ -980,7 +980,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) { } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E5M2); vllm::Fp8KVCacheDataType::kFp8E5M2);
} }
} else { } else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
} }
......
...@@ -2174,7 +2174,6 @@ def gather_cache(src_cache: torch.Tensor, ...@@ -2174,7 +2174,6 @@ def gather_cache(src_cache: torch.Tensor,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16 #dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype) convert_fp8(dst, dst_fp8, scale, kv_dtype)
else: else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table, torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts) cu_seq_lens, batch_size, seq_starts)
......
...@@ -30,6 +30,7 @@ try: ...@@ -30,6 +30,7 @@ try:
except AttributeError: except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment] tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.
...@@ -212,9 +213,9 @@ class Attention(nn.Module): ...@@ -212,9 +213,9 @@ class Attention(nn.Module):
# attn_metadata = get_forward_context().attn_metadata # attn_metadata = get_forward_context().attn_metadata
# #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)): # #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None: # if key is not None and value is not None:
# self.calc_kv_scales(query, key, value) # self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name) self.layer_name)
if self.use_output: if self.use_output:
output_shape = (output_shape output_shape = (output_shape
if output_shape is not None else query.shape) if output_shape is not None else query.shape)
...@@ -439,6 +440,44 @@ direct_register_custom_op( ...@@ -439,6 +440,44 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,) tags=tag_cudagraph_unsafe,)
def maybe_calc_kv_scales(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Run KV-scale calculation on the attention layer named ``layer_name``.

    The layer object is looked up in the current forward context's
    ``no_compile_layers`` mapping and its ``calc_kv_scales`` method is
    invoked with the given tensors.

    Args:
        query: Query tensor forwarded to ``calc_kv_scales``.
        key: Key tensor forwarded to ``calc_kv_scales``.
        value: Value tensor forwarded to ``calc_kv_scales``.
        layer_name: Key identifying the layer in the forward context.
    """
    ctx: ForwardContext = get_forward_context()

    # The per-layer metadata is resolved but not consulted afterwards: the
    # `enable_kv_scales_calculation` early-return guard that used it is
    # disabled in this version, so the scales are computed on every call.
    # The dict lookup is kept so a missing layer entry still raises here.
    layer_metadata = ctx.attn_metadata
    if isinstance(layer_metadata, dict):
        layer_metadata = layer_metadata[layer_name]

    layer = ctx.no_compile_layers[layer_name]
    layer.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``maybe_calc_kv_scales``; does nothing.

    Accepts the same arguments as the real op and returns ``None`` without
    touching any of them.
    """
    return None
# Register `maybe_calc_kv_scales` as a custom op named "maybe_calc_kv_scales",
# with `maybe_calc_kv_scales_fake` as its fake implementation. The attention
# layer invokes it as `torch.ops.vllm.maybe_calc_kv_scales(...)`.
# NOTE(review): `mutates_args` is empty and the op returns None, yet the real
# implementation updates state on the layer object. Confirm that graph
# compilation cannot treat the call as dead code and drop it - TODO verify.
direct_register_custom_op(
    op_name="maybe_calc_kv_scales",
    op_func=maybe_calc_kv_scales,
    mutates_args=[],
    fake_impl=maybe_calc_kv_scales_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=tag_cudagraph_unsafe,)
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
......
...@@ -100,7 +100,7 @@ def flash_mla_with_kvcache( ...@@ -100,7 +100,7 @@ def flash_mla_with_kvcache(
softmax_scale = q.shape[-1]**(-0.5) softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm(): if current_platform.is_rocm():
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q, q,
k_cache, k_cache,
......
...@@ -326,7 +326,7 @@ class ModelConfig: ...@@ -326,7 +326,7 @@ class ModelConfig:
"""Whether to disable sliding window. If True, we will disable the sliding """Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the window functionality of the model, capping to sliding window size. If the
model does not support sliding window, this argument is ignored.""" model does not support sliding window, this argument is ignored."""
disable_cascade_attn: bool = False disable_cascade_attn: bool = True
"""Disable cascade attention for V1. While cascade attention does not """Disable cascade attention for V1. While cascade attention does not
change the mathematical correctness, disabling it could be useful for change the mathematical correctness, disabling it could be useful for
preventing potential numerical issues. Note that even if this is set to preventing potential numerical issues. Note that even if this is set to
...@@ -419,10 +419,6 @@ class ModelConfig: ...@@ -419,10 +419,6 @@ class ModelConfig:
override_attention_dtype: Optional[str] = None override_attention_dtype: Optional[str] = None
"""Override dtype for attention""" """Override dtype for attention"""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
...@@ -452,7 +448,6 @@ class ModelConfig: ...@@ -452,7 +448,6 @@ class ModelConfig:
factors.append(self.rope_theta) factors.append(self.rope_theta)
# hf_config can control how the model looks! # hf_config can control how the model looks!
factors.append(self.hf_config.to_json_string()) factors.append(self.hf_config.to_json_string())
factors.append(self.enable_chunked_prefill)
str_factors = str(factors) str_factors = str(factors)
assert_hashable(str_factors) assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest() return hashlib.sha256(str(factors).encode()).hexdigest()
......
...@@ -1004,7 +1004,6 @@ class EngineArgs: ...@@ -1004,7 +1004,6 @@ class EngineArgs:
enable_sleep_mode=self.enable_sleep_mode, enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl, model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype, override_attention_dtype=self.override_attention_dtype,
enable_chunked_prefill=self.enable_chunked_prefill,
) )
def create_load_config(self) -> LoadConfig: def create_load_config(self) -> LoadConfig:
...@@ -1594,9 +1593,6 @@ class EngineArgs: ...@@ -1594,9 +1593,6 @@ class EngineArgs:
# For pooling tasks the default is False # For pooling tasks the default is False
if model_config.runner_type != "pooling": if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True self.enable_chunked_prefill = True
if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False
if self.enable_prefix_caching is None: if self.enable_prefix_caching is None:
self.enable_prefix_caching = True self.enable_prefix_caching = True
else: else:
...@@ -1610,10 +1606,6 @@ class EngineArgs: ...@@ -1610,10 +1606,6 @@ class EngineArgs:
action = "Enabling" if \ action = "Enabling" if \
incremental_prefill_supported else "Disabling" incremental_prefill_supported else "Disabling"
if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False
if self.enable_chunked_prefill is None: if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported self.enable_chunked_prefill = incremental_prefill_supported
......
...@@ -166,6 +166,9 @@ if TYPE_CHECKING: ...@@ -166,6 +166,9 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHTOP: bool = False VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
...@@ -1104,6 +1107,18 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1104,6 +1107,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_CAT": "VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "False").lower() in
("true", "1")),
# vLLM will use opt merge_attn_states, not triton # vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
......
...@@ -53,6 +53,136 @@ logger = init_logger(__name__) ...@@ -53,6 +53,136 @@ logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13: if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None moe_cache_singleton = None
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
    """Sum ``x`` over dim 1 into ``out``, then scale ``out`` in place.

    Args:
        x: Tensor to reduce along dimension 1.
        out: Destination tensor; receives the sum and is then multiplied
            in place by ``routed_scaling_factor``.
        routed_scaling_factor: Multiplier applied to the reduced result.
    """
    torch.sum(x, 1, out=out)
    out *= routed_scaling_factor
# Triton kernel: for each token, sum the `topk_num` rows of the input along
# its second dimension, scale the sum by `routed_scaling_factor`, and write
# the result to the output row for that token.
#
# Grid axis 0 selects a block of BLOCK_M tokens; grid axis 1 selects a block
# of BLOCK_DIM elements of the hidden dimension.
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    # Widen the strides to int64 before they are used in pointer offsets
    # (presumably to avoid 32-bit overflow on large tensors - TODO confirm).
    # `input_stride_2` and `output_stride_1` are accepted but never read:
    # the last-dimension offset below is `offs_dim` itself.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    # Token range handled by this program, clamped to `token_num`.
    token_start = token_block_id * BLOCK_M
    token_end = min((token_block_id + 1) * BLOCK_M, token_num)

    # Hidden-dim range handled by this program, clamped to `hidden_dim`;
    # `dim_end` is used as the load/store mask bound.
    dim_start = dim_block_id * BLOCK_DIM
    dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)

    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)

    for token_index in range(token_start, token_end):
        # Accumulate in float32 regardless of the input element type.
        accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
            # Out-of-range lanes load 0.0 so they do not affect the sum.
            tmp = tl.load(
                input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
            )
            accumulator += tmp
        accumulator = accumulator * routed_scaling_factor
        store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        # The result is cast to the *input* pointer's element type before the
        # store. NOTE(review): this assumes the output tensor has the same
        # dtype as the input - confirm against callers.
        tl.store(
            store_t_ptr,
            accumulator.to(input_ptr.dtype.element_ty),
            mask=offs_dim < dim_end,
        )
def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    """Launch ``_moe_sum_reduce_kernel`` to reduce ``input`` over dim 1.

    ``input`` is unpacked as ``(token_num, topk_num, hidden_dim)``; the
    first two dimensions of ``output`` must be ``(token_num, hidden_dim)``.
    Both tensors must be contiguous. The result is written into ``output``.

    Args:
        input: Contiguous 3-D tensor to reduce.
        output: Contiguous destination tensor.
        routed_scaling_factor: Scale passed through to the kernel.

    Raises:
        AssertionError: If either tensor is not contiguous or the output
            shape does not match the input.
    """
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    # Launch parameters are bucketed by token count, expressed as
    # (BLOCK_DIM, NUM_STAGE, num_warps). BLOCK_M is 1 in every bucket.
    if token_num <= 32:
        launch_cfg = (512, 2, 4)
    elif token_num <= 128:
        launch_cfg = (1024, 0, 2)
    elif token_num <= 4096:
        launch_cfg = (2048, 0, 2)
    else:
        launch_cfg = (2048, 2, 8)
    block_m = 1
    block_dim, num_stage, warps = launch_cfg

    # One program per (token block, hidden-dim block) pair.
    launch_grid = (
        triton.cdiv(token_num, block_m),
        triton.cdiv(hidden_dim, block_dim),
    )

    _moe_sum_reduce_kernel[launch_grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=block_m,
        BLOCK_DIM=block_dim,
        NUM_STAGE=num_stage,
        num_warps=warps,
    )
def moe_reduce_dispatch(
    intermediate_cache3: torch.Tensor,
    out_hidden_states: torch.Tensor,
    begin_chunk_idx: int,
    end_chunk_idx: int,
):
    """Reduce ``intermediate_cache3`` over dim 1 into the output chunk.

    The reduction backend is chosen from ``n``, the size of the first
    dimension of ``intermediate_cache3``:

    * ``1 <= n <= 4``: ``moe_sum_reduce_torch_compile`` (scale 1.0)
    * ``4 < n <= 1024``: ``moe_sum_reduce_triton`` (scale 1.0)
    * ``1024 < n <= 32768``: ``ops.moe_sum_opt1``
    * otherwise (including ``n == 0``): ``ops.moe_sum``

    Args:
        intermediate_cache3: Tensor to reduce; its first dimension is the
            number of tokens in the current chunk.
        out_hidden_states: Destination. Either the full output tensor, which
            is sliced with ``[begin_chunk_idx:end_chunk_idx]``, or a tensor
            that has already been sliced down to the current chunk.
        begin_chunk_idx: Start row of the chunk in the full output tensor.
        end_chunk_idx: End row (exclusive) of the chunk in the full output.
    """
    n = intermediate_cache3.shape[0]

    # Callers may hand over a destination that is already sliced to the
    # chunk. Slicing it again by the absolute chunk indices would select the
    # wrong (typically empty) rows for every chunk after the first, so only
    # slice when the destination is not already chunk-sized.
    out_chunk = out_hidden_states
    if out_chunk.shape[0] != n:
        out_chunk = out_hidden_states[begin_chunk_idx:end_chunk_idx]

    # Pick the reduce implementation according to the size of n.
    if 1 <= n <= 4:
        moe_sum_reduce_torch_compile(intermediate_cache3, out_chunk, 1.0)
    elif 4 < n <= 1024:
        moe_sum_reduce_triton(intermediate_cache3, out_chunk, 1.0)
    elif 1024 < n <= 32768:
        ops.moe_sum_opt1(intermediate_cache3, out_chunk)
    else:
        ops.moe_sum(intermediate_cache3, out_chunk)
def get_moe_cache(top_k_num,N,K,device,dtype): def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton global moe_cache_singleton
if moe_cache_singleton is None: if moe_cache_singleton is None:
...@@ -1789,8 +1919,16 @@ def fused_experts_impl( ...@@ -1789,8 +1919,16 @@ def fused_experts_impl(
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), # ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor # out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else: else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()), if envs.VLLM_USE_LIGHTOP_MOE_SUM:
out_hidden_states[begin_chunk_idx:end_chunk_idx]) from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0)
elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states return out_hidden_states
......
...@@ -240,8 +240,16 @@ def moe_align_block_size( ...@@ -240,8 +240,16 @@ def moe_align_block_size(
expert_mask = expert_mask, expert_mask = expert_mask,
num_local_tokens = None) num_local_tokens = None)
else: else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, if envs.VLLM_USE_LIGHTOP_MOE_ALIGN:
expert_ids, num_tokens_post_pad) from lightop import op as op
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad,
expert_map = None,
expert_mask = None,
num_local_tokens = None)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None: if expert_map is not None:
expert_ids = expert_map[expert_ids] expert_ids = expert_map[expert_ids]
......
...@@ -21,6 +21,7 @@ from vllm.utils import W8a8GetCacheJSON ...@@ -21,6 +21,7 @@ from vllm.utils import W8a8GetCacheJSON
import os import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
try: try:
...@@ -441,3 +442,5 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -441,3 +442,5 @@ class SlimQuantW4A8Int8MoEMethod:
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
...@@ -925,7 +925,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -925,7 +925,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton.topk = config.num_experts_per_tok self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method self.tritonsingleton.quant_method=self.quant_method
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size dp_size = get_dp_group().world_size
...@@ -1120,7 +1120,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -1120,7 +1120,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if is_pp_missing_parameter(name, self): if is_pp_missing_parameter(name, self):
continue continue
param = params_dict[name] try:
param = params_dict[name]
except Exception as e:
continue
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import copy import copy
import gc import gc
import time import time
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment