"examples/vscode:/vscode.git/clone" did not exist on "7df61696f57a11fbefb850c28acde501fd5b753f"
Unverified commit dea9c0dc, authored by Nicolas Patry and committed by GitHub

Fixing rocm. (#2164)

parent b966bc0d
@@ -2,6 +2,7 @@ import os
 import torch
 from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models.globals import FLASH_DECODING
+from text_generation_server.layers.attention import Seqlen
 from loguru import logger
 
 major, minor = torch.cuda.get_device_capability()
@@ -45,8 +46,7 @@ def paged_attention(
     kv_head_mapping: torch.Tensor,
     softmax_scale: float,
     block_tables: torch.Tensor,
-    cu_seqlen_q: torch.Tensor,
-    cu_seqlen_k: torch.Tensor,
+    input_lengths: Seqlen,
     max_s: int,
 ):
     # Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
@@ -70,7 +70,7 @@ def paged_attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    input_lengths = cu_seqlen_k
+    input_lengths = input_lengths.input_lengths
 
     # NOTE(woosuk): We use a simple heuristic to decide whether to use
     # PagedAttention V1 or V2. If the number of partitions is 1, we use
...
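The substance of the fix: the ROCm paged_attention wrapper still accepted the old cu_seqlen_q / cu_seqlen_k tensors while callers had moved to passing a single Seqlen object, so the signature and the length lookup are updated to match. Below is a minimal sketch of the new shape of the call, not the actual TGI implementation: the Seqlen class is reduced to the one field this diff dereferences, _PARTITION_SIZE = 512 is an assumed value for the vLLM-derived kernel constant, and paged_attention_stub is a hypothetical stand-in for the real kernel wrapper.

    from dataclasses import dataclass
    import torch

    _PARTITION_SIZE = 512  # assumed value, for illustration only

    @dataclass
    class Seqlen:
        # Reduced sketch of text_generation_server.layers.attention.Seqlen:
        # only the field this diff touches.
        input_lengths: torch.Tensor  # per-sequence KV lengths, shape (num_seqs,)

    def paged_attention_stub(query: torch.Tensor, input_lengths: Seqlen, max_s: int):
        # Hypothetical stand-in showing the two changed behaviors:
        # unwrap the tensor the kernel actually consumes (was cu_seqlen_k),
        # and compute the V1/V2 partition count by ceiling division.
        lengths = input_lengths.input_lengths
        max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
        return lengths, max_num_partitions

    # Usage: a batch of three sequences; the longest (1024 tokens) splits into
    # ceil(1024 / 512) = 2 partitions, so the V2 heuristic would apply.
    seqlen = Seqlen(input_lengths=torch.tensor([7, 128, 1024], dtype=torch.int32))
    lengths, parts = paged_attention_stub(torch.empty(0), seqlen, max_s=1024)
    assert parts == 2

Passing one Seqlen object instead of two loose tensors keeps the ROCm signature in step with the other attention backends, so call sites do not need per-backend argument lists.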