Commit 66b809cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.2' into v0.7.2-dev

parents 37b63c24 0408efc6
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer ROCm GPUs.""" """Attention layer ROCm GPUs."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
......
# SPDX-License-Identifier: Apache-2.0
""" Attention layer with torch scaled_dot_product_attention """ Attention layer with torch scaled_dot_product_attention
and PagedAttention.""" and PagedAttention."""
from dataclasses import dataclass from dataclasses import dataclass
......
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -24,7 +26,6 @@ from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata ...@@ -24,7 +26,6 @@ from vllm.attention.backends.mla.utils import MLACommonImpl, MLACommonMetadata
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx, compute_slot_mapping_start_idx,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import async_tensor_h2d, make_tensor_with_pad from vllm.utils import async_tensor_h2d, make_tensor_with_pad
...@@ -70,14 +71,14 @@ class TritonMLABackend(AttentionBackend): ...@@ -70,14 +71,14 @@ class TritonMLABackend(AttentionBackend):
dst_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor, src_to_dst: torch.Tensor,
) -> None: ) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor, src_to_dists: torch.Tensor,
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) ops.copy_blocks_mla(kv_caches, src_to_dists)
@staticmethod @staticmethod
def get_supported_head_sizes() -> List[int]: def get_supported_head_sizes() -> List[int]:
......
# SPDX-License-Identifier: Apache-2.0
"""Attention backend utils""" """Attention backend utils"""
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer with xFormers and PagedAttention.""" """Attention layer with xFormers and PagedAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
......
# SPDX-License-Identifier: Apache-2.0
"""Attention layer.""" """Attention layer."""
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
...@@ -155,9 +156,13 @@ class Attention(nn.Module): ...@@ -155,9 +156,13 @@ class Attention(nn.Module):
kv_cache: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
if self.calculate_kv_scales and \ # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
attn_metadata.enable_kv_scales_calculation: # directly, use `self.kv_cache` and
self.calc_kv_scales(key, value) # `get_forward_context().attn_metadata` instead.
if self.calculate_kv_scales:
ctx_attn_metadata = get_forward_context().attn_metadata
if ctx_attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value)
if self.use_output: if self.use_output:
output = torch.empty_like(query) output = torch.empty_like(query)
hidden_size = query.size(-1) hidden_size = query.size(-1)
...@@ -171,15 +176,27 @@ class Attention(nn.Module): ...@@ -171,15 +176,27 @@ class Attention(nn.Module):
if value is not None: if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call: if self.use_direct_call:
unified_attention_with_output(query, key, value, output, forward_context: ForwardContext = get_forward_context()
self.layer_name) ctx_attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
key,
value,
self_kv_cache,
ctx_attn_metadata,
output=output)
else: else:
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name) query, key, value, output, self.layer_name)
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
if self.use_direct_call: if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name) forward_context = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value,
self_kv_cache, ctx_attn_metadata)
else: else:
return torch.ops.vllm.unified_attention( return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name) query, key, value, self.layer_name)
......
# SPDX-License-Identifier: Apache-2.0
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
......
# SPDX-License-Identifier: Apache-2.0
import math import math
import torch import torch
......
# SPDX-License-Identifier: Apache-2.0
# Helper functions for 3D sparse pattern # Helper functions for 3D sparse pattern
# These function are not optimized and very inefficient. # These function are not optimized and very inefficient.
# Avoid calling them too frequent or use a cache mechanism. # Avoid calling them too frequent or use a cache mechanism.
......
# SPDX-License-Identifier: Apache-2.0
############################################################################### ###############################################################################
# Copyright (C) 2024 Habana Labs, Ltd. an Intel Company # Copyright (C) 2024 Habana Labs, Ltd. an Intel Company
############################################################################### ###############################################################################
......
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
try: try:
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
import neuronxcc.nki.isa as nisa import neuronxcc.nki.isa as nisa
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
......
# SPDX-License-Identifier: Apache-2.0
# The kernels in this file are adapted from LightLLM's context_attention_fwd: # The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
...@@ -9,9 +11,12 @@ from vllm.platforms import current_platform ...@@ -9,9 +11,12 @@ from vllm.platforms import current_platform
# Static kernels parameters # Static kernels parameters
# BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 # BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64
# NUM_WARPS = 4 if current_platform.is_rocm() else 8
BASE_BLOCK = 32 if current_platform.has_device_capability(80) else 32 BASE_BLOCK = 32 if current_platform.has_device_capability(80) else 32
NUM_WARPS = 8 NUM_WARPS = 8
# To check compatibility # To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5) IS_TURING = current_platform.get_device_capability() == (7, 5)
......
# SPDX-License-Identifier: Apache-2.0
# Adapted from # Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py # https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from # which was originally adapted from
...@@ -202,10 +204,10 @@ def _decode_att_m_fwd( ...@@ -202,10 +204,10 @@ def _decode_att_m_fwd(
Req_to_tokens.stride(0), Req_to_tokens.stride(0),
q.stride(0), q.stride(0),
q.stride(1), q.stride(1),
k_buffer.stride(-2), k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-1), k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-1), v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0), att_out.stride(0),
att_out.stride(1), att_out.stride(1),
att_out.stride(2), att_out.stride(2),
...@@ -436,10 +438,10 @@ def _decode_grouped_att_m_fwd( ...@@ -436,10 +438,10 @@ def _decode_grouped_att_m_fwd(
Req_to_tokens.stride(0), Req_to_tokens.stride(0),
q.stride(0), q.stride(0),
q.stride(1), q.stride(1),
k_buffer.stride(-2), k_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
k_buffer.stride(-1), k_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-2), v_buffer.stride(-3), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
v_buffer.stride(-1), v_buffer.stride(-2), # Assume (..., PAGE_SIZE, NUM_HEADS, HEAD_DIM)
att_out.stride(0), att_out.stride(0),
att_out.stride(1), att_out.stride(1),
att_out.stride(2), att_out.stride(2),
......
#!/usr/bin/env python #!/usr/bin/env python
# SPDX-License-Identifier: Apache-2.0
""" """
Fused Attention Fused Attention
=============== ===============
......
# SPDX-License-Identifier: Apache-2.0
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from functools import cache from functools import cache
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
......
# SPDX-License-Identifier: Apache-2.0
import ast import ast
import copy import copy
import dataclasses import dataclasses
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment