Commit ad141d07 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' into v0.9.2-dev-ds

parents 36a7e89e f7e9c329
......@@ -965,7 +965,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
}
} else if (kv_cache_dtype == "fp8_e5m2") {
} else if (kv_cache_dtype == "fp8_e5m2") {
if (src_cache.dtype() == at::ScalarType::Float) {
CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E5M2);
} else if (src_cache.dtype() == at::ScalarType::Half) {
......@@ -980,7 +980,7 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
} else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
CALL_CONVERT_FP8(__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E5M2);
}
}
} else {
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
......
......@@ -2174,7 +2174,6 @@ def gather_cache(src_cache: torch.Tensor,
cu_seq_lens, batch_size, seq_starts)
#dst_fp8->bf16
convert_fp8(dst, dst_fp8, scale, kv_dtype)
else:
torch.ops._C_cache_ops.gather_cache(src_cache, dst, block_table,
cu_seq_lens, batch_size, seq_starts)
......
......@@ -30,6 +30,7 @@ try:
except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module):
"""Attention layer.
......@@ -212,9 +213,9 @@ class Attention(nn.Module):
# attn_metadata = get_forward_context().attn_metadata
# #if (attn_metadata is not None and getattr(attn_metadata, "enable_kv_scales_calculation", False)):
# if key is not None and value is not None:
# self.calc_kv_scales(query, key, value)
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
self.layer_name)
if self.use_output:
output_shape = (output_shape
if output_shape is not None else query.shape)
......@@ -439,6 +440,44 @@ direct_register_custom_op(
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
# if attn_metadata is None or not getattr(
# attn_metadata, 'enable_kv_scales_calculation', False):
# return
self = forward_context.no_compile_layers[layer_name]
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake( query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=[],
fake_impl=maybe_calc_kv_scales_fake,
dispatch_key=current_platform.dispatch_key,
tags=tag_cudagraph_unsafe,)
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
......
......@@ -100,7 +100,7 @@ def flash_mla_with_kvcache(
softmax_scale = q.shape[-1]**(-0.5)
if current_platform.is_rocm():
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
kv_dtype = "fp8_e4m3" if kv_cache_dtype == "fp8" else kv_cache_dtype
out, softmax_lse = flash_mla_cuda.fwd_kvcache_quantization_mla(
q,
k_cache,
......
......@@ -326,7 +326,7 @@ class ModelConfig:
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
model does not support sliding window, this argument is ignored."""
disable_cascade_attn: bool = False
disable_cascade_attn: bool = True
"""Disable cascade attention for V1. While cascade attention does not
change the mathematical correctness, disabling it could be useful for
preventing potential numerical issues. Note that even if this is set to
......@@ -419,10 +419,6 @@ class ModelConfig:
override_attention_dtype: Optional[str] = None
"""Override dtype for attention"""
enable_chunked_prefill: Optional[bool] = None
"""If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
......@@ -452,7 +448,6 @@ class ModelConfig:
factors.append(self.rope_theta)
# hf_config can control how the model looks!
factors.append(self.hf_config.to_json_string())
factors.append(self.enable_chunked_prefill)
str_factors = str(factors)
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
......
......@@ -1004,7 +1004,6 @@ class EngineArgs:
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
enable_chunked_prefill=self.enable_chunked_prefill,
)
def create_load_config(self) -> LoadConfig:
......@@ -1594,9 +1593,6 @@ class EngineArgs:
# For pooling tasks the default is False
if model_config.runner_type != "pooling":
self.enable_chunked_prefill = True
if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
else:
......@@ -1610,10 +1606,6 @@ class EngineArgs:
action = "Enabling" if \
incremental_prefill_supported else "Disabling"
if model_config.enable_chunked_prefill is not None and \
model_config.enable_chunked_prefill is False:
self.enable_chunked_prefill = False
if self.enable_chunked_prefill is None:
self.enable_chunked_prefill = incremental_prefill_supported
......
......@@ -166,6 +166,9 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHTOP: bool = False
VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
USE_FUSED_SILU_MUL_QUANT: bool = False
......@@ -1104,6 +1107,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_CAT":
lambda: (os.environ.get("VLLM_USE_OPT_CAT", "False").lower() in
("true", "1")),
# vLLM will use triton moe_sum
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "False").lower() in
("true", "1")),
# vLLM will use opt merge_aatn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
......
......@@ -53,6 +53,136 @@ logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
torch.sum(x, dim=1, out=out)
out.mul_(routed_scaling_factor)
@triton.jit
def _moe_sum_reduce_kernel(
input_ptr,
input_stride_0,
input_stride_1,
input_stride_2,
output_ptr,
output_stride_0,
output_stride_1,
token_num: int,
topk_num: int,
hidden_dim: int,
routed_scaling_factor: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DIM: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)
token_block_id = tl.program_id(0)
dim_block_id = tl.program_id(1)
token_start = token_block_id * BLOCK_M
token_end = min((token_block_id + 1) * BLOCK_M, token_num)
dim_start = dim_block_id * BLOCK_DIM
dim_end = min((dim_block_id + 1) * BLOCK_DIM, hidden_dim)
offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
for token_index in range(token_start, token_end):
accumulator = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
input_t_ptr = input_ptr + token_index * input_stride_0 + offs_dim
for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
tmp = tl.load(
input_t_ptr + i * input_stride_1, mask=offs_dim < dim_end, other=0.0
)
accumulator += tmp
accumulator = accumulator * routed_scaling_factor
store_t_ptr = output_ptr + token_index * output_stride_0 + offs_dim
tl.store(
store_t_ptr,
accumulator.to(input_ptr.dtype.element_ty),
mask=offs_dim < dim_end,
)
def moe_sum_reduce_triton(
input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
assert input.is_contiguous()
assert output.is_contiguous()
token_num, topk_num, hidden_dim = input.shape
assert output.shape[0] == token_num and output.shape[1] == hidden_dim
if token_num <= 32:
BLOCK_M = 1
BLOCK_DIM = 512
NUM_STAGE = 2
num_warps = 4
elif token_num <= 128:
BLOCK_M = 1
BLOCK_DIM = 1024
NUM_STAGE = 0
num_warps = 2
elif token_num <= 4096:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 0
num_warps = 2
else:
BLOCK_M = 1
BLOCK_DIM = 2048
NUM_STAGE = 2
num_warps = 8
grid = (
triton.cdiv(token_num, BLOCK_M),
triton.cdiv(hidden_dim, BLOCK_DIM),
)
_moe_sum_reduce_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
token_num=token_num,
topk_num=topk_num,
hidden_dim=hidden_dim,
routed_scaling_factor=routed_scaling_factor,
BLOCK_M=BLOCK_M,
BLOCK_DIM=BLOCK_DIM,
NUM_STAGE=NUM_STAGE,
num_warps=num_warps,
)
return
def moe_reduce_dispatch(
intermediate_cache3: torch.Tensor,
out_hidden_states: torch.Tensor,
begin_chunk_idx: int,
end_chunk_idx: int,
):
inter_cache_view = intermediate_cache3.view(*intermediate_cache3.shape)
n = intermediate_cache3.shape[0]
# 根据 n 大小选择不同的 reduce 实现
if 1 <= n <= 4:
moe_sum_reduce_torch_compile(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 4 < n <= 1024:
moe_sum_reduce_triton(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx], 1.0)
elif 1024 < n <= 32768:
ops.moe_sum_opt1(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
else:
ops.moe_sum(inter_cache_view, out_hidden_states[begin_chunk_idx:end_chunk_idx])
def get_moe_cache(top_k_num,N,K,device,dtype):
global moe_cache_singleton
if moe_cache_singleton is None:
......@@ -1789,8 +1919,16 @@ def fused_experts_impl(
# ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
# out_hidden_states[begin_chunk_idx:end_chunk_idx]) * routed_scaling_factor
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
if envs.VLLM_USE_LIGHTOP_MOE_SUM:
from lightop import op as op
op.moe_sum(input=intermediate_cache3.view(*intermediate_cache3.size()),
output=out_hidden_states[begin_chunk_idx:end_chunk_idx], bias=None,
expert_mask=None, num_local_tokens=None, factor=1.0)
elif envs.VLLM_USE_OPT_MOE_SUM:
moe_reduce_dispatch(intermediate_cache3.view(*intermediate_cache3.size()), out_hidden_states[begin_chunk_idx:end_chunk_idx], begin_chunk_idx, end_chunk_idx)
else:
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.size()),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
return out_hidden_states
......
......@@ -240,8 +240,16 @@ def moe_align_block_size(
expert_mask = expert_mask,
num_local_tokens = None)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if envs.VLLM_USE_LIGHTOP_MOE_ALIGN:
from lightop import op as op
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad,
expert_map = None,
expert_mask = None,
num_local_tokens = None)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
......
......@@ -21,6 +21,7 @@ from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
from vllm import envs
try:
......@@ -441,3 +442,5 @@ class SlimQuantW4A8Int8MoEMethod:
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
......@@ -925,7 +925,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.tritonsingleton= W8a8GetCacheJSON()
self.tritonsingleton.topk = config.num_experts_per_tok
self.tritonsingleton.quant_method=self.quant_method
self.tritonsingleton.quant_method=self.quant_method
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
......@@ -1120,7 +1120,10 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
try:
param = params_dict[name]
except Exception as e:
continue
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import copy
import gc
import time
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment