Commit 66b809cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
......
# SPDX-License-Identifier: Apache-2.0
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
......
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
......@@ -24,7 +26,6 @@ from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
......@@ -70,14 +71,14 @@ class TritonMLABackend(AttentionBackend):
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
ops.copy_blocks_mla(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
......
# SPDX-License-Identifier: Apache-2.0
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer."""
from typing import Any, Dict, List, Optional
......@@ -155,9 +156,13 @@ class Attention(nn.Module):
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.calculate_kv_scales and \
attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
# NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
# directly, use `self.kv_cache` and
# `get_forward_context().attn_metadata` instead.
if self.calculate_kv_scales:
ctx_attn_metadata = get_forward_context().attn_metadata
if ctx_attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
if self.use_output:
output = torch.empty_like(query)
hidden_size = query.size(-1)
......@@ -171,15 +176,27 @@ class Attention(nn.Module):
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call:
unified_attention_with_output(query, key, value, output,
self.layer_name)
forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
self_kv_cache,
ctx_attn_metadata,
output=output)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name)
return output.view(-1, hidden_size)
else:
if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name)
forward_context = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value,
self_kv_cache, ctx_attn_metadata)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name)
......
# SPDX-License-Identifier: Apache-2.0
import torch
import triton
import triton.language as tl
......
# SPDX-License-Identifier: Apache-2.0
import math
import torch
......
# SPDX-License-Identifier: Apache-2.0
# Helper functions for 3D sparse pattern
# These function are not optimized and very inefficient.
# Avoid calling them too frequent or use a cache mechanism.
......
# SPDX-License-Identifier: Apache-2.0
###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
###############################################################################
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple
try:
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
import neuronxcc.nki.isa as nisa
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import List, Optional, Tuple
......
# SPDX-License-Identifier: Apache-2.0
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
......@@ -9,9 +11,12 @@ from vllm.platforms import current_platform
# Static kernels parameters
# BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
# NUM_WARPS = 4 if current_platform.is_rocm() else 8
BASE_BLOCK = 32 if current_platform.has_device_capability(80) else 32
NUM_WARPS = 8
# To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5)
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
......@@ -202,10 +204,10 @@ def _decode_att_m_fwd(
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-2),
k_buffer.stride(-1),
v_buffer.stride(-2),
v_buffer.stride(-1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
......@@ -436,10 +438,10 @@ def _decode_grouped_att_m_fwd(
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-2),
k_buffer.stride(-1),
v_buffer.stride(-2),
v_buffer.stride(-1),
k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
......
#!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
"""
Fused Attention
===============
......
# SPDX-License-Identifier: Apache-2.0
import os
from contextlib import contextmanager
from functools import cache
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
......
# SPDX-License-Identifier: Apache-2.0
import ast
import copy
import dataclasses
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment