Unverified commit 232e8d52, authored by fxmarty and committed by GitHub

MI300 compatibility (#1764)

Adds support for AMD Instinct MI300 in TGI.

The main changes are:
* Support PyTorch TunableOp to pick the best GEMM/GEMV kernels for decoding
(https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable).
TunableOp is enabled by default on ROCm and can be disabled with
`PYTORCH_TUNABLEOP_ENABLED=0`; a minimal standalone usage sketch is shown
after this list.
* Update the ROCm Dockerfile to PyTorch 2.3 (patched with the changes from
https://github.com/pytorch/pytorch/pull/124362).
* Support SILU & Linear custom kernels contributed by AMD
* Update the vLLM paged attention kernels to https://github.com/fxmarty/rocm-vllm/,
branched off a much more recent ROCm/vllm commit
(https://github.com/ROCm/vllm/commit/3489ce7936c5de588916ae3047c44c23c0b0c308).


* Support the Flash Attention 2 Triton kernel recommended by AMD. It can be
enabled with `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update the Dockerfile to ROCm 6.1.
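
For reference, the same TunableOp machinery can be driven directly from PyTorch, outside of TGI. The sketch below is illustrative only (the matrix shapes and output filename are made up); it uses the `torch.cuda.tunable` API available in PyTorch 2.3+, which is the API the server-side warmup in this PR relies on.

```python
import torch

# Minimal standalone sketch (illustrative shapes and filename, not part of this PR):
# tune one fp16 GEMM with TunableOp and persist the selected kernel so a later run
# can replay it instead of re-tuning.
torch.cuda.tunable.enable(True)          # turn TunableOp on for this process
torch.cuda.tunable.tuning_enable(True)   # allow tuning, not just replay of prior results

a = torch.randn(7, 8192, dtype=torch.float16, device="cuda")
b = torch.randn(8192, 28672, dtype=torch.float16, device="cuda")
_ = a @ b  # the first matmul with a new shape triggers kernel selection

torch.cuda.tunable.write_file("tunableop_results.csv")  # arbitrary output path
torch.cuda.tunable.tuning_enable(False)                 # later GEMMs only replay tuned kernels
```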

By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) so that the
tuning does not have to be rerun on every `docker run`.

Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
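
To check which kernel was selected for each GEMM shape, the saved CSV can be inspected with a few lines of Python. A minimal sketch (the path below is the default `/data` location mentioned above and will differ depending on the model and setup):

```python
import csv

# Minimal sketch: print the kernel TunableOp selected for each GEMM shape.
# The last column is the timing reported by TunableOp for the chosen kernel.
path = "/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv"

with open(path, newline="") as f:
    for row in csv.reader(f):
        if row and row[0].startswith("GemmTunableOp"):
            op_type, shape, kernel, timing = row
            print(f"{shape}: {kernel} (measured time {timing})")
```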

---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
parent a60fa840
@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.layers import (
     TensorParallelRowLinear,
@@ -40,6 +41,13 @@ from text_generation_server.layers.layernorm import (
 )

+if SYSTEM == "rocm":
+    try:
+        from vllm import _custom_C
+    except Exception as e:
+        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+
+
 class MistralConfig(PretrainedConfig):
     model_type = "mistral"
@@ -251,14 +259,16 @@ class MistralAttention(torch.nn.Module):
 class MistralMLP(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
-        act = config.hidden_act
+        self.hidden_act = config.hidden_act
         self.act = (
-            ACT2FN[act]
-            if "gelu" not in act
+            ACT2FN[self.hidden_act]
+            if "gelu" not in self.hidden_act
             else lambda x: torch.nn.functional.gelu(
                 x,
                 approximate=(
-                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
+                    "tanh"
+                    if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
+                    else "none"
                 ),
             )
         )
@@ -281,6 +291,20 @@ class MistralMLP(nn.Module):
         )

     def forward(self, hidden_states):
+        if (
+            SYSTEM == "rocm"
+            and self.hidden_act == "silu"
+            and hidden_states.shape[0] == 1
+        ):
+            out = torch.empty(
+                hidden_states.shape[0],
+                self.intermediate_size,
+                dtype=hidden_states.dtype,
+                device="cuda",
+            )
+            _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
+            return self.down_proj(out)
+        else:
             gate_up_states = self.gate_up_proj(hidden_states)
             gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
             return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])

@@ -60,7 +60,7 @@ from text_generation_server.utils.import_utils import SYSTEM
 if SYSTEM == "cuda":
     import dropout_layer_norm
 elif SYSTEM == "rocm":
-    from vllm import layernorm_ops
+    from vllm._C import ops
 else:
     raise RuntimeError(f"Unsupported system {SYSTEM}")
@@ -420,7 +420,7 @@ class IdeficsRMSNorm(nn.Module):
             hidden_states = hidden_states.reshape(-1, shape[-1])
             out = torch.empty_like(hidden_states)
-            layernorm_ops.rms_norm(
+            ops.rms_norm(
                 out,
                 hidden_states,
                 self.weight.data,

@@ -12,6 +12,9 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
@@ -28,6 +31,7 @@ from text_generation_server.models.cache_manager import (
 )
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
+import text_generation_server.models.globals as tgi_globals
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -783,6 +787,9 @@ class FlashCausalLM(Model):
             )
             max_bt = batch.max_blocks
             max_s = max_bt * get_cache_manager().block_size
+
+            if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
+                torch.cuda.tunable.tuning_enable(False)
             _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
@@ -820,6 +827,49 @@ class FlashCausalLM(Model):
             self.device,
         )

+        if SYSTEM == "rocm":
+            if (
+                os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
+                or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
+            ):
+                if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
+                    torch.cuda.tunable.tuning_enable(True)
+
+                if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
+                    tuning_sequences = [
+                        int(val)
+                        for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
+                    ]
+                else:
+                    tuning_sequences = CUDA_GRAPHS
+
+                tunableop_filepath = os.path.join(
+                    HUGGINGFACE_HUB_CACHE,
+                    f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                )
+
+                logger.info(
+                    f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
+                )
+
+                if os.path.isfile(tunableop_filepath):
+                    logger.info(
+                        f"The file {tunableop_filepath} already exists and will be reused."
+                    )
+                    torch.cuda.tunable.read_file(tunableop_filepath)
+
+                os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
+
+                for seqlen in tuning_sequences:
+                    logger.info(f"Warming up TunableOp for seqlen={seqlen}")
+                    self.tunableop_warmup(seqlen)
+                torch.cuda.tunable.write_file(tunableop_filepath)
+                torch.cuda.tunable.tuning_enable(False)
+            else:
+                logger.info(
+                    "PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
                )
+
         if CUDA_GRAPHS:
             try:
                 logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -834,6 +884,27 @@ class FlashCausalLM(Model):
         return int(num_blocks * BLOCK_SIZE)

+    def tunableop_warmup(self, seqlen: int):
+        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
+        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
+        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
+        kv_cache = get_cache_manager().kv_cache
+
+        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
+        self.model.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cu_seqlen_prefill=torch.tensor(
+                [0, seqlen], device=self.device, dtype=torch.int32
+            ),
+            kv_cache=get_cache_manager().kv_cache,
+            block_tables=None,
+            input_lengths=None,
+            slots=slots,
+            max_s=seqlen,
+            lm_head_indices=None,
+        )
+
     def forward(
         self, batch: FlashCausalLMBatch
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -1113,8 +1184,6 @@ class FlashCausalLM(Model):
             next_token_texts = []
             left = 0

-            logger.debug(f"Accepted ids {n_accepted_ids}")
-
             current_stopped = False
             for j in range(index, index + n_accepted_ids):
                 # Generated token

@@ -15,11 +15,10 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)

-from text_generation_server.utils.import_utils import SYSTEM
-

 class FlashGPT2(FlashCausalLM):
     def __init__(

@@ -15,3 +15,12 @@ else:
     cuda_graphs = None

 CUDA_GRAPHS = cuda_graphs
+
+# This is overridden at model loading.
+global MODEL_ID
+MODEL_ID = None
+
+
+def set_model_id(model_id: str):
+    global MODEL_ID
+    MODEL_ID = model_id

@@ -21,6 +21,7 @@ from text_generation_server.models.vlm_causal_lm import (
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
+from text_generation_server.models.globals import set_model_id


 class SignalHandler:
@@ -252,6 +253,7 @@ def serve(
         while signal_handler.KEEP_PROCESSING:
             await asyncio.sleep(0.5)

+    set_model_id(model_id)
     asyncio.run(
         serve_inner(
             model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code

@@ -2,14 +2,18 @@ import os
 import torch

 from loguru import logger
+import math

 from text_generation_server.utils.import_utils import SYSTEM
+from text_generation_server.utils.flash_attn_triton import triton_attention

 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")

-HAS_FLASH_ATTN = True
+HAS_FLASH_ATTN = False
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False
+ROCM_USE_FLASH_ATTN_V2_CK = False
+ROCM_USE_FLASH_ATTN_V2_TRITON = False

 if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex
@@ -57,10 +61,21 @@ if SYSTEM in {"cuda", "rocm"}:
     is_sm75 = major == 7 and minor == 5
     is_sm8x = major == 8 and minor >= 0
     is_sm90 = major == 9 and minor == 0
+    is_sm94 = major == 9 and minor == 4
+
+    if SYSTEM == "rocm":
+        if (
+            os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
+            or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
+        ):
+            ROCM_USE_FLASH_ATTN_V2_TRITON = True
+            logger.info("ROCm: using Flash Attention 2 Triton implementation.")
+        else:
+            ROCM_USE_FLASH_ATTN_V2_CK = True
+            logger.info(
+                "ROCm: using Flash Attention 2 Composable Kernel implementation."
+            )

-    HAS_FLASH_ATTN = False
-    HAS_FLASH_ATTN_V2_CUDA = False
-    HAS_FLASH_ATTN_V2_ROCM = False
     try:
         try:
             import flash_attn_2_cuda
@@ -71,11 +86,16 @@ if SYSTEM in {"cuda", "rocm"}:
                 "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
                 f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
             )
-        if not (is_sm8x or is_sm90):
+        if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
             raise ImportError(
                 f"GPU with CUDA capability {major} {minor} is not supported for "
                 "Flash Attention V2"
             )
+        elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
+            raise ImportError(
+                f"AMD GPU with compute capability {major} {minor} is not supported for "
+                "Flash Attention V2"
+            )
         HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
         HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
     except ImportError as e:
@@ -142,7 +162,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
             None,
         )

-elif HAS_FLASH_ATTN_V2_ROCM:
+elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:

     def attention(
         q,
@@ -153,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM:
         max_s,
         softmax_scale,
         window_size_left=-1,
+        causal=True,
     ):
         if window_size_left <= 0 and window_size_left != -1:
             raise ValueError("`window_size_left` must be > 0 or -1")
@@ -174,11 +195,38 @@ elif HAS_FLASH_ATTN_V2_ROCM:
             0.0,
             softmax_scale,
             False,
-            True,
+            causal,
             False,
             None,
         )

+elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+        causal=True,
+    ):
+        output, _ = triton_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            causal,
+            softmax_scale,
+        )
+        return output
+
 elif HAS_FLASH_ATTN:

     def attention(

This diff is collapsed.
@@ -5,6 +5,14 @@ _PARTITION_SIZE = 512

 if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex
+else:
+    try:
+        from vllm._C import cache_ops
+        from vllm._C import ops
+    except Exception as e:
+        raise ImportError(
+            f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
+        )


 def reshape_and_cache(
@@ -14,22 +22,14 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if SYSTEM == "cuda":
-        from vllm._C import cache_ops
-
-        cache_ops.reshape_and_cache(
-            key, value, key_cache, value_cache, slots, "auto", 1.0
-        )
-    elif SYSTEM == "rocm":
-        from vllm import cache_ops
-
-        cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-    elif SYSTEM == "xpu":
+    if SYSTEM == "xpu":
         ipex.llm.modules.PagedAttention.reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
     else:
-        raise ValueError("vllm is not supported on your system")
+        cache_ops.reshape_and_cache(
+            key, value, key_cache, value_cache, slots, "auto", 1.0
+        )


 def attention(
@@ -87,9 +87,6 @@ def attention(
     # to parallelize.
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        if SYSTEM == "cuda":
-            from vllm._C import ops
-
         ops.paged_attention_v1(
             out,
             query,
@@ -105,25 +102,6 @@ def attention(
             "auto",
             1.0,
         )
-        elif SYSTEM == "rocm":
-            from vllm import attention_ops
-
-            attention_ops.paged_attention_v1(
-                out,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-            )
-        else:
-            raise ValueError("vllm is not supported on your system")
     else:
         # Run PagedAttention V2.
         assert _PARTITION_SIZE % block_size == 0
@@ -139,9 +117,6 @@ def attention(
         )
         max_logits = torch.empty_like(exp_sums)

-        if SYSTEM == "cuda":
-            from vllm._C import ops
-
         ops.paged_attention_v2(
             out,
             exp_sums,
@@ -160,24 +135,3 @@ def attention(
             "auto",
             1.0,
         )
-        elif SYSTEM == "rocm":
-            from vllm import attention_ops
-
-            attention_ops.paged_attention_v2(
-                out,
-                exp_sums,
-                max_logits,
-                tmp_output,
-                query,
-                key_cache,
-                value_cache,
-                kv_head_mapping,
-                softmax_scale,
-                block_tables,
-                input_lengths,
-                block_size,
-                max_s,
-                None,
-            )
-        else:
-            raise ValueError("vllm is not supported on your system")