Commit 8824ae6a authored by 王敏's avatar 王敏
Browse files

merge 092-dev分支近期修改

parents f9f1887d c0707728
......@@ -10,7 +10,7 @@ from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_2_marlin_weight
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
......@@ -22,6 +22,7 @@ try:
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
class MarlinMoeWorkspace:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
......@@ -205,16 +206,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w2_weight_scale.data, requires_grad=False
)
w1_marlin_list = []
for e in range(layer.w13_weight.shape[0]):
w1_marlin_in = w4a8_2_marlin_weight(layer.w13_weight[e])
w1_marlin_list.append(w1_marlin_in)
layer.w13_weight = Parameter(torch.stack(w1_marlin_list, dim=0), requires_grad=False)
w2_marlin_list = []
for e in range(layer.w2_weight.shape[0]):
w2_marlin_in = w4a8_2_marlin_weight(layer.w2_weight[e])
w2_marlin_list.append(w2_marlin_in)
layer.w2_weight = Parameter(torch.stack(w2_marlin_list, dim=0), requires_grad=False)
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
def apply_ep( #dp+ep
self,
......
......@@ -176,15 +176,19 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
supports_router_weight = not layer.apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu"
#暂时只支持bw
device_name = torch.cuda.get_device_properties(torch.cuda.current_device()).name
supports_device = "BW" in device_name
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
# moe marlin requires n % 128 == 0 and k % 64 == 0
supports_shape = hidden_size % 128 == 0 and \
intermediate_size_per_partition % max(64, group_size) == 0
supports_group_size = group_size in [-1, 32, 64, 128]
#暂时只支持64
supports_group_size = group_size in [64]
return supports_shape and supports_group_size and \
supports_router_weight and supports_activation
supports_router_weight and supports_activation and supports_device
def marlin_make_workspace(output_size_per_partition: int,
......
......@@ -2,6 +2,12 @@
import torch
import numpy as np
try:
from lightop import awq_marlin_repack_w4a8
use_lightop = True
except Exception:
use_lightop = False
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
"""
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
......@@ -54,12 +60,12 @@ def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
q_w = q_w.contiguous().to(torch.int32)
M, N = q_w.shape
assert N % pack_factor == 0, f"size_n ({N}) must be divisible by pack_factor ({pack_factor})"
q_packed = torch.zeros((M, N // pack_factor), dtype=torch.int32, device=orig_device)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << 4 * i
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
q_packed += q_w[:, i::pack_factor] << (4 * i)
return q_packed
......@@ -70,3 +76,18 @@ def w4a8_2_marlin_weight(w4a8_w):
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
return marlin_q_w
def w4a8_weight_repack_impl(input):
if use_lightop:
size_batch = input.shape[0]
size_n = input.shape[1]
size_k = input.shape[2] * 2
output = torch.zeros((size_batch, size_k // 32, size_n * 4), device=input.device, dtype=torch.int32)
awq_marlin_repack_w4a8(input, output, size_batch, size_k, size_n)
else:
w_marlin_list = []
for e in range(input.shape[0]):
w_marlin_in = w4a8_2_marlin_weight(input[e])
w_marlin_list.append(w_marlin_in)
output = torch.stack(w_marlin_list, dim=0)
return output
\ No newline at end of file
......@@ -40,6 +40,8 @@ from vllm.platforms import current_platform
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm():
from flash_attn.layers.rotary import apply_rotary_emb
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
......
......@@ -96,11 +96,21 @@ class DeepseekV2MLP(nn.Module):
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
def forward(self, x,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = False
):
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekV2MoE(nn.Module):
......@@ -153,10 +163,10 @@ class DeepseekV2MoE(nn.Module):
self.n_local_physical_experts)
self.physical_expert_end = (self.physical_expert_start +
self.n_local_physical_experts)
dp_size = get_dp_group().world_size
self.use_all2all_ep = envs.VLLM_USE_ALLTOALL_EP and dp_size > 1 and parallel_config.enable_expert_parallel
moe_cls = FusedMoE if not self.use_all2all_ep else EPMoE
self.experts = moe_cls(
num_experts=config.n_routed_experts,
......@@ -179,8 +189,8 @@ class DeepseekV2MoE(nn.Module):
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
shared_expert_cls = DeepseekV2MLP if not self.use_all2all_ep else EPSharedExperts
self.shared_experts = shared_expert_cls(
#shared_expert_cls = DeepseekV2MLP if not self.use_all2all_ep else EPSharedExperts
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
......@@ -195,13 +205,21 @@ class DeepseekV2MoE(nn.Module):
from vllm.two_batch_overlap.two_batch_overlap import tbo_all_reduce
self.tbo_all_reduce = tbo_all_reduce
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if not self.use_all2all_ep:
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
if not self.use_all2all_ep:
......@@ -215,9 +233,9 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
final_hidden_states, new_resi = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not self.use_all2all_ep:
if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
......@@ -235,8 +253,10 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = (
self.experts.maybe_all_reduce_tensor_model_parallel(
final_hidden_states))
return final_hidden_states.view(num_tokens, hidden_dim)
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi
else:
return final_hidden_states.view(num_tokens, hidden_dim)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
......@@ -437,19 +457,36 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
if envs.USE_FUSED_RMS_QUANT:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
eps=config.rms_norm_eps,
prefix=f"{prefix}.q_b_proj")
else:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
else:
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
......@@ -524,31 +561,60 @@ class DeepseekV2MLAAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
q, _, _ = self.q_b_proj(q_c, rms_weight=self.q_a_layernorm.weight.data, residual=None, update_hd=False)
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states, quant_args=input_quant_args, update_hd=False)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0], new_residual
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
if self.q_lora_rank is not None:
q_c = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(q_c)
q = self.q_b_proj(q_c)[0]
else:
q = self.q_proj(hidden_states)[0]
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
# Add head dim of 1 to k_pe
k_pe = k_pe.unsqueeze(1)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
positions, q[..., self.qk_nope_head_dim:], k_pe)
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
attn_out = self.mla_attn(
q,
kv_c_normed,
k_pe,
output_shape=(hidden_states.shape[0],
self.num_local_heads * self.v_head_dim))
return self.o_proj(attn_out)[0]
class DeepseekV2DecoderLayer(nn.Module):
......@@ -623,47 +689,90 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
if envs.USE_FUSED_RMS_QUANT:
# Fix residual FP16 overflow
residual_fix_overflow = False
assert self.input_layernorm.has_weight is True
if residual is None:
residual = hidden_states
hidden_states, _ = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = None
)
residual_fix_overflow = True
else:
hidden_states, new_residual = self.self_attn(
positions = positions,
hidden_states = hidden_states,
rms_weight = self.input_layernorm.weight.data,
residual = residual
)
residual = new_residual
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, new_resi
else:
hidden_states, residual = self.input_layernorm(
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
if hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on
# first layer.
residual *= 1. / self.routed_scaling_factor
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16 and not self.dpsk_fp16_quick:
# Fix FP16 overflow
# Scaling the DeepseekV2MLP output, it is the input of
# input_layernorm of next decoder layer.
# The scaling of DeepseekV2MOE output would be done in the forward
# of DeepseekV2MOE
hidden_states *= 1. / self.routed_scaling_factor
return hidden_states, residual
return hidden_states, residual
@support_torch_compile
......@@ -984,7 +1093,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
# However it's not mapped locally to this rank
# So we simply skip it
continue
if self.use_all2all_ep:
name = name.replace(ep_moe_shared_experts_keys, ep_moe_shared_experts_mapping[ep_moe_shared_experts_keys])
# Skip loading extra bias for GPTQ models.
......
......@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, is_pp_missing_parameter,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -330,7 +331,10 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if not current_platform.is_rocm():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
else:
from flash_attn.layers.rotary import apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
......
......@@ -436,7 +436,8 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
x = x.to(memory_format=torch.channels_last_3d)
if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.hidden_size)
return x
......
......@@ -246,6 +246,8 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor,
apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm():
from flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output
......@@ -464,7 +466,8 @@ class Qwen2VisionPatchEmbed(nn.Module):
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
self.patch_size)
x = x.to(memory_format=torch.channels_last_3d)
if os.environ.get('PYTORCH_MIOPEN_SUGGEST_NDHWC') == '1':
x = x.to(memory_format=torch.channels_last_3d)
x = self.proj(x).view(L, self.embed_dim)
return x
......
......@@ -287,18 +287,29 @@ def tbo_split_and_execute_model(
attn_metadata_left = prepare_tbo_atten_metadata(runner, input_split.scheduler_output_left, input_split.req_ids_left, 0)
attn_metadata_right = prepare_tbo_atten_metadata(runner, input_split.scheduler_output_right, input_split.req_ids_right, input_split.req_num_left)
model_output = tbo_model_executable_v1(
runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
num_tokens_across_dp,
input_ids,
positions,
intermediate_tensors,
inputs_embeds)
finished_sending, finished_recving = None, None
with set_forward_context(attn_metadata,
runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
skip_cuda_graphs=True):
runner.maybe_setup_kv_connector(scheduler_output)
model_output = tbo_model_executable_v1(
runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
num_tokens_across_dp,
input_ids,
positions,
intermediate_tensors,
inputs_embeds)
runner.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
runner.get_finished_kv_transfers(scheduler_output))
#finished_sending, finished_recving = None, None
else:
# Run the decoder.
# Use persistent buffers for CUDA graphs.
......
......@@ -162,6 +162,14 @@ def init_two_batch_overlap():
tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread()
def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
from vllm.attention.layer import maybe_save_kv_layer_to_connector
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid:
return
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
......
......@@ -669,30 +669,56 @@ class FlashAttentionImpl(AttentionImpl):
assert not use_local_attn, (
"Cascade attention does not support local attention.")
# Cascade attention (rare case).
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
if not current_platform.is_rocm():
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
q_descale=layer._q_scale,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
)
else:
cascade_attention(
output[:num_actual_tokens],
query[:num_actual_tokens],
key_cache,
value_cache,
cu_query_lens=attn_metadata.query_start_loc,
max_query_len=attn_metadata.max_query_len,
cu_prefix_query_lens=attn_metadata.cu_prefix_query_lens,
prefix_kv_lens=attn_metadata.prefix_kv_lens,
suffix_kv_lens=attn_metadata.suffix_kv_lens,
max_kv_len=attn_metadata.max_seq_len,
softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window,
logits_soft_cap=self.logits_soft_cap,
block_table=attn_metadata.block_table,
common_prefix_len=attn_metadata.common_prefix_len,
fa_version=2, #self.vllm_flash_attn_version,
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
)
return output
......@@ -825,6 +851,31 @@ def cascade_attention(
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
)
else:
prefix_output, prefix_lse, _ = vllm_flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_prefix_query_lens,
seqused_k=prefix_kv_lens,
max_seqlen_q=num_tokens,
max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale,
causal=False,
window_size=sliding_window,
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
is_prefix_cache=True,
)
descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
......@@ -853,6 +904,31 @@ def cascade_attention(
v_descale=v_descale.expand(descale_shape)
if v_descale is not None else None,
)
else:
suffix_output, suffix_lse, _ = vllm_flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
seqused_k=suffix_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len - common_prefix_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window,
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None,
is_prefix_cache=True,
)
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
......
......@@ -216,6 +216,7 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -777,10 +778,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9
and torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 )
if not current_platform.is_rocm():
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
else:
self._pad_v = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120
def _flash_attn_varlen_diff_headdims(self,
q,
......@@ -921,8 +925,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
......@@ -977,7 +989,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if k_nope.shape[0] > 1024:
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
else:
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
output = self._flash_attn_varlen_diff_headdims(
q=q,
......
import triton
import triton.language as tl
import torch
from functools import reduce
import pytest
import torch
import math
@pytest.mark.parametrize("shape_pair,dim", [
(((4, 8, 512), (4, 8, 64)), 2),
(((8, 8, 512), (8, 8, 64)), 2),
(((16, 8, 512), (16, 8, 64)), 2),
(((32, 8, 512), (32, 8, 64)), 2),
(((64, 8, 512), (64, 8, 64)), 2),
(((128, 8, 512), (128, 8, 64)), 2),
(((256, 8, 512), (256, 8, 64)), 2),
(((512, 8, 512), (512, 8, 64)), 2),
(((672, 8, 512), (672, 8, 64)), 2),
(((768, 8, 512), (768, 8, 64)), 2),
(((896, 8, 512), (896, 8, 64)), 2),
(((1024, 8, 512), (1024, 8, 64)), 2),
(((4, 16, 512), (4, 16, 64)), 2),
(((8, 16, 512), (8, 16, 64)), 2),
(((16, 16, 512), (16, 16, 64)), 2),
(((32, 16, 512), (32, 16, 64)), 2),
(((64, 16, 512), (64, 16, 64)), 2),
(((128, 16, 512), (128, 16, 64)), 2),
(((256, 16, 512), (256, 16, 64)), 2),
(((512, 16, 512), (512, 16, 64)), 2),
(((672, 16, 512), (672, 16, 64)), 2),
(((768, 16, 512), (768, 16, 64)), 2),
(((896, 16, 512), (896, 16, 64)), 2),
(((1024, 16, 512), (1024, 16, 64)), 2),
(((4, 32, 512), (4, 32, 64)), 2),
(((8, 32, 512), (8, 32, 64)), 2),
(((16, 32, 512), (16, 32, 64)), 2),
(((32, 32, 512), (32, 32, 64)), 2),
(((64, 32, 512), (64, 32, 64)), 2),
(((128, 32, 512), (128, 32, 64)), 2),
(((256, 32, 512), (256, 32, 64)), 2),
(((512, 32, 512), (512, 32, 64)), 2),
(((672, 32, 512), (672, 32, 64)), 2),
(((768, 32, 512), (768, 32, 64)), 2),
(((896, 32, 512), (896, 32, 64)), 2),
(((1024, 32, 512), (1024, 32, 64)), 2),
(((4, 32, 128), (4, 32, 64)), 2),
(((8, 32, 128), (8, 32, 64)), 2),
(((16, 32, 128), (16, 32, 64)), 2),
(((32, 32, 128), (32, 32, 64)), 2),
(((64, 32, 128), (64, 32, 64)), 2),
(((128, 32, 128), (128, 32, 64)), 2),
(((256, 32, 128), (256, 32, 64)), 2),
(((512, 32, 128), (512, 32, 64)), 2),
(((672, 32, 128), (672, 32, 64)), 2),
(((768, 32, 128), (768, 32, 64)), 2),
(((896, 32, 128), (896, 32, 64)), 2),
(((1024, 32, 128), (1024, 32, 64)), 2),
])
def test_concat_Acc(shape_pair, dim):
torch.manual_seed(1)
shape1, shape2 = shape_pair
x = torch.randn(*shape1, device='cuda', dtype=torch.bfloat16)
y = torch.randn(*shape2, device='cuda', dtype=torch.bfloat16)
expected = torch.cat([x,y], dim=dim)
result = concat_helper(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
@triton.jit
def concat_kernel_prefill(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)# 获取当前block的索引
for sub_section_index in range(Per_block//2):
sub_section_offset = block_idx * Per_block + sub_section_index * 2
if sub_section_offset <= section_num-1:
C_section_start = C_ptr + sub_section_offset * C_section_numel
A_section_start = A_ptr + sub_section_offset * A_section_numel
B_section_start = B_ptr + sub_section_offset * B_section_numel
Arrange_doubleA = tl.arange(0, 256)
mask = Arrange_doubleA < (256)
Arrange2 = (tl.arange(0, 128)[None,:] + tl.arange(0, 2)[:,None]).reshape(256)
val_from_A = tl.load(A_section_start + Arrange_doubleA)
tensorAsn = tl.full((256,), 0, tl.int32)
tensorAsn2 = tl.full((256,), (C_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleA < A_section_numel,tensorAsn , tensorAsn2)
off = Arrange2 + tensor_offsets
tl.store(C_section_start + off,val_from_A,mask=mask)
Arrange_doubleB = tl.arange(0, 128)
mask = Arrange_doubleB < (B_section_numel*2)
val_from_B = tl.load(B_section_start + Arrange_doubleB,mask=mask)
Arrange3 = (tl.arange(0, 64)[None,:] + tl.arange(0, 2)[:,None]).reshape(128)
tensorAsn = tl.full((128,), A_section_numel, tl.int32)
tensorAsn2 = tl.full((128,), (C_section_numel + A_section_numel-1), tl.int32)
tensor_offsets = tl.where(Arrange_doubleB < B_section_numel,tensorAsn , tensorAsn2)
tl.store(C_section_start+ Arrange3 + tensor_offsets , val_from_B)
@triton.jit
def concat_kernel(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)
for sub_section_index in range(Per_block):
sub_offset = block_idx * Per_block + sub_section_index
if sub_offset <= section_num-1:
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
A_ptr_block_start = A_ptr + sub_offset * A_section_numel
B_ptr_block_start = B_ptr + sub_offset * B_section_numel
for offset in range(0, A_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < A_section_numel
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
for offset in range(0, B_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < B_section_numel
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
A = A.contiguous()
B = B.contiguous()
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim])
Per_block = 1
unit_offset_A, unit_offset_B, unit_offset_C = A.stride(dim-1),B.stride(dim-1),C.stride(dim-1)
#case prefill
if (A.shape[2] == 128 and B.shape[2] == 64 and A.shape[0] > 16):
Per_block = 8
num_blocks = math.ceil(block_num/Per_block)
concat_kernel_prefill[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
else:
if (A.shape[1]==8 and A.shape[0] > 128) or ( A.shape[1]==16 and A.shape[0] > 96) or ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 64):
Per_block = 2
num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
BLOCK_SIZE=1024)
return C
assert False, "not support"
configs = []
configs.append(
triton.testing.Benchmark(
x_names=['size'],
x_vals=[4,8,16,32,64,96,128,256,512,768,1024],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='s',
plot_name='concat-dim2',
args={"dim":2},
),
)
@triton.testing.perf_report(configs)
def benchmark(size, provider, dim):
x = torch.rand([size,8,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,8,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_16(size, provider, dim):
x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_32(size, provider, dim):
x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs)
def benchmark_prefill(size, provider, dim):
x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__':
# benchmark.run(save_path="./triton_test_8",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file
import triton
import triton.language as tl
import torch
from functools import reduce
import pytest
import torch
import math
@pytest.mark.parametrize("shape_pair,dim", [
(((4, 8, 512), (4, 8, 64)), 2),
(((8, 8, 512), (8, 8, 64)), 2),
(((16, 8, 512), (16, 8, 64)), 2),
(((32, 8, 512), (32, 8, 64)), 2),
(((64, 8, 512), (64, 8, 64)), 2),
(((128, 8, 512), (128, 8, 64)), 2),
(((256, 8, 512), (256, 8, 64)), 2),
(((512, 8, 512), (512, 8, 64)), 2),
(((672, 8, 512), (672, 8, 64)), 2),
(((768, 8, 512), (768, 8, 64)), 2),
(((896, 8, 512), (896, 8, 64)), 2),
(((1024, 8, 512), (1024, 8, 64)), 2),
(((4, 16, 512), (4, 16, 64)), 2),
(((8, 16, 512), (8, 16, 64)), 2),
(((16, 16, 512), (16, 16, 64)), 2),
(((32, 16, 512), (32, 16, 64)), 2),
(((64, 16, 512), (64, 16, 64)), 2),
(((128, 16, 512), (128, 16, 64)), 2),
(((256, 16, 512), (256, 16, 64)), 2),
(((512, 16, 512), (512, 16, 64)), 2),
(((672, 16, 512), (672, 16, 64)), 2),
(((768, 16, 512), (768, 16, 64)), 2),
(((896, 16, 512), (896, 16, 64)), 2),
(((1024, 16, 512), (1024, 16, 64)), 2),
(((4, 32, 512), (4, 32, 64)), 2),
(((8, 32, 512), (8, 32, 64)), 2),
(((16, 32, 512), (16, 32, 64)), 2),
(((32, 32, 512), (32, 32, 64)), 2),
(((64, 32, 512), (64, 32, 64)), 2),
(((128, 32, 512), (128, 32, 64)), 2),
(((256, 32, 512), (256, 32, 64)), 2),
(((512, 32, 512), (512, 32, 64)), 2),
(((672, 32, 512), (672, 32, 64)), 2),
(((768, 32, 512), (768, 32, 64)), 2),
(((896, 32, 512), (896, 32, 64)), 2),
(((1024, 32, 512), (1024, 32, 64)), 2),
])
def test_concat_Acc(shape_pair, dim):
torch.manual_seed(1)
shape1, shape2 = shape_pair
M = shape1[0]
N = shape1[1]
x_sizes = [M, N, 512]
x_strides = [512, 512*M, 1]
x_max_index = M * N * 512
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
# print("形状:", x.shape) # [4, 8, 512]
# print("步幅:", x.stride()) # (1536, 192, 1)
y_sizes = [M, N, 64]
y_strides = [1536*(N//8), 192, 1]
y_max_index = 1536*(N//8) * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
expected = torch.cat([x,y], dim=dim)
result = concat_helper(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
@triton.jit
def concat_kernel(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block,
section_num,
M,
N,
Astride_0,
Astride_1,
Astride_2,
Bstride_0,
Bstride_1,
Bstride_2,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)
for sub_section_index in range(Per_block):
sub_offset = block_idx * Per_block + sub_section_index
M_idx = sub_offset // N
N_idx = sub_offset % N
if sub_offset <= section_num-1:
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
A_ptr_block_start = A_ptr + M_idx * Astride_0 + N_idx * Astride_1
B_ptr_block_start = B_ptr + M_idx * Bstride_0 + N_idx * Bstride_1
for offset in range(0, A_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < A_section_numel
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
for offset in range(0, B_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < B_section_numel
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + A_section_numel + offset_idx, val_from_B, mask=mask)
def concat_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 :
block_num = reduce(lambda x, y: x * y, output_shape[:dim])
Per_block = 1
unit_offset_A, unit_offset_B, unit_offset_C = A.shape[dim],B.shape[dim],C.shape[dim]
if (A.shape[1]==8 and A.shape[0] > 512) or ( A.shape[1]==16 and A.shape[0] > 256):
Per_block = 2
if ( A.shape[1]==32 and A.shape[2] == 512 and A.shape[0] > 256):
Per_block = 8
num_blocks = math.ceil(block_num/Per_block)
concat_kernel[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block,
block_num,
output_shape[0],
output_shape[1],
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
BLOCK_SIZE=1024)
return C
assert False, "not support"
configs = []
configs.append(
triton.testing.Benchmark(
x_names=['M','N'],
x_vals=[(4,8),(8,8),(16,8),(32,8),(64,8),(96,8),(128,8),(256,8),(512,8),(768,8),(1024,8), \
(4,16),(8,16),(16,16),(32,16),(64,16),(96,16),(128,16),(256,16),(512,16),(768,16),(1024,16), \
(4,32),(8,32),(16,32),(32,32),(64,32),(96,32),(128,32),(256,32),(512,32),(768,32),(1024,32)],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='s',
plot_name='concat-dim2',
args={"dim":2},
),
)
@triton.testing.perf_report(configs)
def benchmark(M, N, provider, dim):
x_sizes = [M, N, 512]
x_strides = [512, 512*M, 1]
x_max_index = M * N * 512
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
# print("形状:", x.shape) # [M, 8, 512]
# print("步幅:", x.stride()) # (512, 512*M, 1)
y_sizes = [M, N, 64]
y_strides = [1536*(N//8), 192, 1]
y_max_index = 1536*(N//8) * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
# print("形状:", y.shape) # [M, 8, 64]
# print("步幅:", y.stride()) # (1536, 192, 1)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_16(size, provider, dim):
# x = torch.rand([size,16,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,16,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_32(size, provider, dim):
# x = torch.rand([size,32,512], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
# @triton.testing.perf_report(configs)
# def benchmark_prefill(size, provider, dim):
# x = torch.rand([size,32,128], device='cuda', dtype=torch.bfloat16)
# y = torch.rand([size,32,64], device='cuda', dtype=torch.bfloat16)
# quantiles = [0.5, 0.2, 0.8]
# if provider == 'torch':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
# if provider == 'triton':
# ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper(x, y,dim=dim), quantiles=quantiles)
# return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__':
benchmark.run(save_path="./triton_test",print_data=True)
# benchmark_16.run(save_path="./triton_test_16",print_data=True)
# benchmark_32.run(save_path="./triton_test_32",print_data=True)
# benchmark_prefill.run(save_path="./triton_test_prefill",print_data=True)
\ No newline at end of file
......@@ -19,6 +19,8 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonMetadataBuilder)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__)
......@@ -164,8 +166,16 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
if envs.VLLM_USE_TRITON_CAT:
if q_nope.shape[0] <= 1024:
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
else:
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache(
q=q,
......
import triton
import triton.language as tl
import torch
from functools import reduce
import pytest
import torch
import math
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim):
torch.manual_seed(1)
shape1, shape2 = shape_pair
M = shape1[0]
N = shape1[1]
x_sizes = [M, N, 128]
x_strides = [N//8 * 2048, 256, 1]
x_max_index = N//8 * 2048 * M
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
y_sizes = [M, N, 64]
y_strides = [576, 0, 1]
y_max_index = 576 * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
expected = torch.cat([x,y], dim=dim)
result = concat_prefill_helper_Triton(x, y, dim=dim)
result_lightop = lightop_concat_prefill_helper(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
# print("精度验证通过")
# print("expected",expected)
# print("result_lightop",result_lightop)
assert torch.allclose(result, result_lightop, rtol=1e-5, atol=1e-5), "result_lightop Mismatch Triton error"
assert torch.allclose(expected, result_lightop, rtol=1e-5, atol=1e-5), "result_lightop Mismatch torch error"
print("prefill 精度验证通过")
def test_concat_Acc_decode(shape_pair, dim):
torch.manual_seed(1)
shape1, shape2 = shape_pair
M = shape1[0]
N = shape1[1]
x_sizes = [M, N, 512]
x_strides = [512, 512*M, 1]
x_max_index = M * N * 512
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
# print("形状:", x.shape)
# print("步幅:", x.stride())
y_sizes = [M, N, 64]
y_strides = [1536*(N//8), 192, 1]
y_max_index = 1536*(N//8) * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
# print("形状:", y.shape)
# print("步幅:", y.stride())
expected = torch.cat([x,y], dim=dim)
result = concat_helper_decode(x, y, dim=dim)
assert torch.allclose(result, expected, rtol=1e-5, atol=1e-5), "Mismatch"
print("decode 精度正常")
@triton.jit
def concat_kernel(
A_ptr, B_ptr, C_ptr,
A_section_numel, B_section_numel, C_section_numel,
Per_block_A,
Per_block_B,
section_numA,
section_numB,
M,
N,
Astride_0,
Astride_1,
Astride_2,
Bstride_0,
Bstride_1,
Bstride_2,
BLOCK_SIZE: tl.constexpr
):
block_idx = tl.program_id(0)
numA = section_numA // Per_block_A
if (block_idx < numA):
#处理A的block
for sub_section_index in range(Per_block_A):
sub_offset = block_idx * Per_block_A + sub_section_index
if sub_offset <= section_numA-1:
M_idx = sub_offset // N
N_idx = sub_offset % N
C_ptr_block_start = C_ptr + sub_offset * C_section_numel
A_ptr_block_start = A_ptr + M_idx * Astride_0 + N_idx * Astride_1
for offset in range(0, A_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < A_section_numel
val_from_A = tl.load(A_ptr_block_start + offset_idx, mask=mask)
tl.store(C_ptr_block_start + offset_idx, val_from_A, mask=mask)
else:
#处理B的block
#shape是1024*8*64,实际上只有1024 * 64 块数据,开了1024/4=256个线程块来处理。每个线程块处理1块连续的数据
#需要注意C的分块也是有M * N 大小的,而这里只有M大小个线程块,每个线程块需要写入N次数据到C中。
for sub_section_index in range(Per_block_B):
sub_offset = (block_idx - numA) * Per_block_B + sub_section_index
if sub_offset <= section_numB-1:
C_ptr_block_start = C_ptr + sub_offset * N * C_section_numel
B_ptr_block_start = B_ptr + sub_offset * Bstride_0
for offset in range(0, B_section_numel, BLOCK_SIZE):
offset_idx = offset + tl.arange(0, BLOCK_SIZE)
mask = offset_idx < B_section_numel
val_from_B = tl.load(B_ptr_block_start + offset_idx, mask=mask)
for idx in range(0,N,1):
tl.store(C_ptr_block_start + idx * C_section_numel + A_section_numel + offset_idx, val_from_B, mask=mask)
def concat_prefill_helper_Triton(A:torch.Tensor, B:torch.Tensor, dim:int):
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim] #128+64=192
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
if dim!=0 :
#分开计算block块A需要
Per_block_A = 64
Per_block_B = 1
#128 \64 \192
unit_offset_A, unit_offset_B, unit_offset_C = A.shape[dim],B.shape[dim],C.shape[dim]
#A的分块数是:M * N 这里的demo是1024 * 8
block_numA = reduce(lambda x, y: x * y, output_shape[:dim])
#B的分块数是:M 这里的demo是1024
block_numB = output_shape[0]
#A的每个分块可以处理多份数据的读取和写入,这是因为单次的任务量太小。假设这里Per_block = 8 那么A就开启了1024个线程块,每个线程块处理8份数据的读取和写入
#B的每个分块处理1次B的读取和8次C的写入,L2 cache复用率高
block_num = block_numA // Per_block_A + block_numB // Per_block_B
num_blocks = math.ceil(block_num)
concat_kernel[(num_blocks,)](
A, B, C,
unit_offset_A, unit_offset_B, unit_offset_C,
Per_block_A,
Per_block_B,
block_numA,
block_numB,
output_shape[0],
output_shape[1],
A.stride(0),
A.stride(1),
A.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
BLOCK_SIZE=1024)
return C
assert False, "not support"
def concat_helper_decode(A:torch.Tensor, B:torch.Tensor, dim:int):
assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
mode = 0
if dim!=0 :
ds_cat( A, B, C, mode)
return C
assert False, "not support"
def lightop_concat_prefill_helper(A:torch.Tensor, B:torch.Tensor, dim:int):
assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
output_shape = list(A.shape)
output_shape[dim] = A.shape[dim] + B.shape[dim]
C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
mode = 6
if dim!=0 :
ds_cat( A, B, C, mode)
return C
assert False, "not support"
configs = []
configs.append(
triton.testing.Benchmark(
x_names=['M','N'],
x_vals=[(1024,8),(2048,8),(3072,8),(4096,8),(6144,8),(8192,8),\
(1024,16),(2048,16),(3072,16),(4096,16),(6144,16),(8192,16),\
(1024,32),(2048,32),(3072,32),(4096,32),(6144,32),(8192,32)
],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch', 'lightop'],
line_names=['Triton', 'Torch','Lightop'],
styles=[('blue', '-'), ('green', '-'), ('yellow', '-')],
ylabel='s',
plot_name='concat-dim2',
args={"dim":2},
),
)
configs_decode = []
configs_decode.append(
triton.testing.Benchmark(
x_names=['M','N'],
x_vals=[(4,8),(8,8),(16,8),(32,8),(64,8),(96,8),(128,8),(256,8),(512,8),(768,8),(767,8),(765,8),(766,8), \
(4,16),(8,16),(16,16),(32,16),(64,16),(96,16),(128,16),(256,16),(512,16),(768,16),(767,16),(765,16),(766,16), \
(4,32),(8,32),(16,32),(32,32),(64,32),(96,32),(128,32),(256,32),(512,32),(768,32),(767,32),(765,32),(766,32)],
x_log=True,
line_arg='provider',
line_vals=['lightop', 'torch'],
line_names=['Lightop', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='s',
plot_name='concat-dim2',
args={"dim":2},
),
)
@triton.testing.perf_report(configs)
def benchmark_prefill(M, N, provider, dim):
x_sizes = [M, N, 128]
x_strides = [N//8 * 2048, 256, 1]
x_max_index = N//8 * 2048 * M
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
y_sizes = [M, N, 64]
y_strides = [576, 0, 1]
y_max_index = 576 * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_prefill_helper_Triton(x, y,dim=dim), quantiles=quantiles)
if provider == 'lightop':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: lightop_concat_prefill_helper(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
@triton.testing.perf_report(configs_decode)
def benchmark_decode(M, N, provider, dim):
x_sizes = [M, N, 512]
x_strides = [512, 512*M, 1]
x_max_index = M * N * 512
x_required_length = x_max_index
x_data = torch.arange(x_required_length,device='cuda').bfloat16()
x = torch.as_strided(x_data, size=x_sizes, stride=x_strides)
# print("形状:", x.shape)
# print("步幅:", x.stride())
y_sizes = [M, N, 64]
y_strides = [1536*(N//8), 192, 1]
y_max_index = 1536*(N//8) * M
y_required_length = y_max_index
y_data = torch.arange(y_required_length,device='cuda').bfloat16()
y = torch.as_strided(y_data, size=y_sizes, stride=y_strides)
# print("形状:", y.shape)
# print("步幅:", y.stride())
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.cat([x,y],dim=dim), quantiles=quantiles)
if provider == 'lightop':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: concat_helper_decode(x, y,dim=dim), quantiles=quantiles)
return (ms*1000), (max_ms*1000), (min_ms*1000)
if __name__ == '__main__':
benchmark_prefill.run(save_path="./triton_test",print_data=True)
test_concat_Acc_prefill(((1024, 8, 128), (1024, 8, 64)), 2)
test_concat_Acc_prefill(((2048, 8, 128), (2048, 8, 64)), 2)
test_concat_Acc_prefill(((4096, 8, 128), (4096, 8, 64)), 2)
test_concat_Acc_prefill(((8192, 8, 128), (8192, 8, 64)), 2)
test_concat_Acc_prefill(((1024, 16, 128), (1024, 16, 64)), 2)
test_concat_Acc_prefill(((2048, 16, 128), (2048, 16, 64)), 2)
test_concat_Acc_prefill(((4096, 16, 128), (4096, 16, 64)), 2)
test_concat_Acc_prefill(((8192, 16, 128), (8192, 16, 64)), 2)
test_concat_Acc_prefill(((1024, 32, 128), (1024, 32, 64)), 2)
test_concat_Acc_prefill(((2048, 32, 128), (2048, 32, 64)), 2)
test_concat_Acc_prefill(((4096, 32, 128), (4096, 32, 64)), 2)
test_concat_Acc_prefill(((8192, 32, 128), (8192, 32, 64)), 2)
benchmark_decode.run(save_path="./cat_triton_test",print_data=True)
test_concat_Acc_decode(((16, 8, 512), (16, 8, 64)), 2)
test_concat_Acc_decode(((32, 8, 512), (32, 8, 64)), 2)
test_concat_Acc_decode(((128, 8, 512), (128, 8, 64)), 2)
test_concat_Acc_decode(((768, 8, 512), (768, 8, 64)), 2)
test_concat_Acc_decode(((32, 16, 512), (32, 16, 64)), 2)
test_concat_Acc_decode(((32, 32, 512), (32, 32, 64)), 2)
test_concat_Acc_decode(((768, 32, 512), (768, 32, 64)), 2)
test_concat_Acc_decode(((128, 32, 512), (128, 32, 64)), 2)
test_concat_Acc_decode(((512, 32, 512), (512, 32, 64)), 2)
test_concat_Acc_decode(((765, 8, 512), (765, 8, 64)), 2)
test_concat_Acc_decode(((766, 8, 512), (766, 8, 64)), 2)
test_concat_Acc_decode(((767, 8, 512), (767, 8, 64)), 2)
test_concat_Acc_decode(((765, 16, 512), (765, 16, 64)), 2)
test_concat_Acc_decode(((766, 16, 512), (766, 16, 64)), 2)
test_concat_Acc_decode(((765, 32, 512), (765, 32, 64)), 2)
test_concat_Acc_decode(((767, 32, 512), (767, 32, 64)), 2)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment