Unverified Commit 232e8d52 authored by fxmarty, committed by GitHub

MI300 compatibility (#1764)

Adds support for AMD Instinct MI300 in TGI.

The main changes are:
* Support PyTorch TunableOp
(https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) to
pick the best GEMM/GEMV kernels for decoding. TunableOp is enabled by default
and can be disabled with `PYTORCH_TUNABLEOP_ENABLED=0`.
* Update the ROCm Dockerfile to PyTorch 2.3 (patched with the changes from
https://github.com/pytorch/pytorch/pull/124362).
* Support the SiLU and linear custom kernels contributed by AMD.
* Update vLLM paged attention to https://github.com/fxmarty/rocm-vllm/,
branched off the much more recent commit
https://github.com/ROCm/vllm/commit/3489ce7936c5de588916ae3047c44c23c0b0c308


* Support the Flash Attention 2 Triton kernel, as recommended by AMD. It can
be enabled by setting `ROCM_USE_FLASH_ATTN_V2_TRITON=1`.
* Update the Dockerfile to ROCm 6.1.

By default, TunableOp tuning results are saved in `/data` (e.g.
`/data/tunableop_meta-llama-Llama-2-70b-chat-hf_tp1_rank0.csv`) so that the
tuning does not have to be rerun at each `docker run` (a standalone sketch of
this flow is shown after the example below).

Example:
```
Validator,PT_VERSION,2.3.0
Validator,ROCM_VERSION,6.1.0.0-82-5fabb4c
Validator,HIPBLASLT_VERSION,0.7.0-1549b021
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
Validator,ROCBLAS_VERSION,4.1.0-cefa4a9b-dirty
GemmTunableOp_Half_TN,tn_8192_7_28672,Gemm_Rocblas_45475,0.132098
GemmTunableOp_Half_TN,tn_10240_4_8192,Gemm_Rocblas_45546,0.0484431
GemmTunableOp_Half_TN,tn_32000_6_8192,Default,0.149546
GemmTunableOp_Half_TN,tn_32000_3_8192,Gemm_Rocblas_45520,0.147119
GemmTunableOp_Half_TN,tn_8192_3_28672,Gemm_Rocblas_45475,0.132645
GemmTunableOp_Half_TN,tn_10240_3_8192,Gemm_Rocblas_45546,0.0482971
GemmTunableOp_Half_TN,tn_57344_5_8192,Gemm_Rocblas_45520,0.255694
GemmTunableOp_Half_TN,tn_10240_7_8192,Gemm_Rocblas_45517,0.0482522
GemmTunableOp_Half_TN,tn_8192_3_8192,Gemm_Rocblas_45546,0.0444671
GemmTunableOp_Half_TN,tn_8192_5_8192,Gemm_Rocblas_45546,0.0445834
GemmTunableOp_Half_TN,tn_57344_7_8192,Gemm_Rocblas_45520,0.25622
GemmTunableOp_Half_TN,tn_8192_2_28672,Gemm_Rocblas_45475,0.132122
GemmTunableOp_Half_TN,tn_8192_4_8192,Gemm_Rocblas_45517,0.0453191
GemmTunableOp_Half_TN,tn_10240_5_8192,Gemm_Rocblas_45517,0.0482514
GemmTunableOp_Half_TN,tn_8192_5_28672,Gemm_Rocblas_45542,0.133914
GemmTunableOp_Half_TN,tn_8192_2_8192,Gemm_Rocblas_45517,0.0446516
GemmTunableOp_Half_TN,tn_8192_1_28672,Gemm_Hipblaslt_TN_10814,0.131953
GemmTunableOp_Half_TN,tn_10240_2_8192,Gemm_Rocblas_45546,0.0481043
GemmTunableOp_Half_TN,tn_32000_4_8192,Gemm_Rocblas_45520,0.147497
GemmTunableOp_Half_TN,tn_8192_6_28672,Gemm_Rocblas_45529,0.134895
GemmTunableOp_Half_TN,tn_57344_2_8192,Gemm_Rocblas_45520,0.254716
GemmTunableOp_Half_TN,tn_57344_4_8192,Gemm_Rocblas_45520,0.255731
GemmTunableOp_Half_TN,tn_10240_6_8192,Gemm_Rocblas_45517,0.0484816
GemmTunableOp_Half_TN,tn_57344_3_8192,Gemm_Rocblas_45520,0.254701
GemmTunableOp_Half_TN,tn_8192_4_28672,Gemm_Rocblas_45475,0.132159
GemmTunableOp_Half_TN,tn_32000_2_8192,Default,0.147524
GemmTunableOp_Half_TN,tn_32000_5_8192,Default,0.147074
GemmTunableOp_Half_TN,tn_8192_6_8192,Gemm_Rocblas_45546,0.0454045
GemmTunableOp_Half_TN,tn_57344_6_8192,Gemm_Rocblas_45520,0.255582
GemmTunableOp_Half_TN,tn_32000_7_8192,Default,0.146705
GemmTunableOp_Half_TN,tn_8192_7_8192,Gemm_Rocblas_45546,0.0445489
```
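
For reference, here is a minimal standalone sketch of the tuning flow the warmup code in this PR drives. It is not part of the PR itself; it assumes a ROCm build of PyTorch 2.3 with TunableOp, and the file name and GEMM shapes are illustrative only.

```
import os

# TunableOp is toggled through the same environment variable TGI reads.
# Set it before importing torch so it is picked up process-wide.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"

import torch

torch.cuda.tunable.tuning_enable(True)

# Reuse previous tuning results if present, as the warmup code in this PR does.
filename = "tunableop_results.csv"  # illustrative; TGI writes per-model files under its cache
if os.path.isfile(filename):
    torch.cuda.tunable.read_file(filename)

# Benchmark the decode GEMM shapes of interest (hidden size 8192, small batches).
weight = torch.randn(8192, 8192, dtype=torch.float16, device="cuda")
for batch in (1, 2, 4, 8):
    x = torch.randn(batch, 8192, dtype=torch.float16, device="cuda")
    x @ weight.t()  # the first call for each shape triggers tuning

torch.cuda.tunable.write_file(filename)
torch.cuda.tunable.tuning_enable(False)  # keep the tuned picks, stop re-tuning
```

Running this a second time with the CSV present loads the previously selected kernels instead of re-tuning, which is exactly why TGI persists the file under `/data`.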

---------
Co-authored-by: Mohit Sharma <mohit21sharma.ms@gmail.com>
parent a60fa840
@@ -26,6 +26,7 @@ from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.layers import (
TensorParallelRowLinear,
@@ -40,6 +41,13 @@ from text_generation_server.layers.layernorm import (
)
if SYSTEM == "rocm":
try:
from vllm import _custom_C
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
class MistralConfig(PretrainedConfig):
model_type = "mistral"
@@ -251,14 +259,16 @@ class MistralAttention(torch.nn.Module):
class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
act = config.hidden_act
self.hidden_act = config.hidden_act
self.act = (
ACT2FN[act]
if "gelu" not in act
ACT2FN[self.hidden_act]
if "gelu" not in self.hidden_act
else lambda x: torch.nn.functional.gelu(
x,
approximate=(
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
"tanh"
if self.hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
else "none"
),
)
)
@@ -281,9 +291,23 @@ class MistralMLP(nn.Module):
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.shape[0] == 1
):
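# ROCm decode fast path: for a single token, AMD's custom LLMM_Silu kernel
# fuses the gate/up projection GEMM with the SiLU-and-multiply activation,
# writing the result directly into `out`.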
out = torch.empty(
hidden_states.shape[0],
self.intermediate_size,
dtype=hidden_states.dtype,
device="cuda",
)
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out)
else:
gate_up_states = self.gate_up_proj(hidden_states)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
class MistralLayer(nn.Module):
@@ -60,7 +60,7 @@ from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda":
import dropout_layer_norm
elif SYSTEM == "rocm":
from vllm import layernorm_ops
from vllm._C import ops
else:
raise RuntimeError(f"Unsupported system {SYSTEM}")
@@ -420,7 +420,7 @@ class IdeficsRMSNorm(nn.Module):
hidden_states = hidden_states.reshape(-1, shape[-1])
out = torch.empty_like(hidden_states)
layernorm_ops.rms_norm(
ops.rms_norm(
out,
hidden_states,
self.weight.data,
@@ -12,6 +12,9 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
@@ -28,6 +31,7 @@ from text_generation_server.models.cache_manager import (
)
from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
import text_generation_server.models.globals as tgi_globals
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
@@ -783,6 +787,9 @@ class FlashCausalLM(Model):
)
max_bt = batch.max_blocks
max_s = max_bt * get_cache_manager().block_size
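# Disable TunableOp tuning for this throwaway forward pass, which is only used
# to size the KV cache; the real tuning pass runs further below during warmup.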
if SYSTEM == "rocm" and os.environ.get("PYTORCH_TUNABLEOP_ENABLED", False):
torch.cuda.tunable.tuning_enable(False)
_, batch, _ = self.generate_token(batch)
except torch.cuda.OutOfMemoryError as e:
raise RuntimeError(
@@ -820,6 +827,49 @@ class FlashCausalLM(Model):
self.device,
)
if SYSTEM == "rocm":
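# TunableOp is on by default on ROCm: PYTORCH_TUNABLEOP_ENABLED=0 opts out,
# and PYTORCH_TUNABLEOP_TUNING=0 reuses existing results without re-tuning.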
if (
os.environ.get("PYTORCH_TUNABLEOP_ENABLED") is None
or os.environ.get("PYTORCH_TUNABLEOP_ENABLED") == "1"
):
if os.environ.get("PYTORCH_TUNABLEOP_TUNING") != "0":
torch.cuda.tunable.tuning_enable(True)
if os.environ.get("PYTORCH_TUNABLEOP_SEQLENS") is not None:
tuning_sequences = [
int(val)
for val in os.environ["PYTORCH_TUNABLEOP_SEQLENS"].split(",")
]
else:
tuning_sequences = CUDA_GRAPHS
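# Default to the CUDA graph batch sizes, so the tuned GEMM shapes match
# what decoding will actually run.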
tunableop_filepath = os.path.join(
HUGGINGFACE_HUB_CACHE,
f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
)
logger.info(
f"PyTorch TunableOp (https://github.com/fxmarty/pytorch/tree/2.3-patched/aten/src/ATen/cuda/tunable) is enabled. The warmup may take several minutes, picking the ROCm optimal matrix multiplication kernel for the target lengths {', '.join([str(seqlen) for seqlen in tuning_sequences])}, with typical 5-8% latency improvement for small sequence lengths. The picked GEMMs are saved in the file {tunableop_filepath}. To disable TunableOp, please launch TGI with `PYTORCH_TUNABLEOP_ENABLED=0`."
)
if os.path.isfile(tunableop_filepath):
logger.info(
f"The file {tunableop_filepath} already exists and will be reused."
)
torch.cuda.tunable.read_file(tunableop_filepath)
os.makedirs(HUGGINGFACE_HUB_CACHE, exist_ok=True)
for seqlen in tuning_sequences:
logger.info(f"Warming up TunableOp for seqlen={seqlen}")
self.tunableop_warmup(seqlen)
torch.cuda.tunable.write_file(tunableop_filepath)
torch.cuda.tunable.tuning_enable(False)
else:
logger.info(
"PyTorch ROCm TunableOp (https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cuda/tunable) is disabled. TunableOp brings an additional 5-8% latency improvement for small sequence lengths but requires a warmup. If necessary, please use the environment variable PYTORCH_TUNABLEOP_ENABLED=1 to enable TunableOp."
)
if CUDA_GRAPHS:
try:
logger.info(f"Cuda Graphs are enabled for sizes {CUDA_GRAPHS}")
@@ -834,6 +884,27 @@ class FlashCausalLM(Model):
return int(num_blocks * BLOCK_SIZE)
def tunableop_warmup(self, seqlen: int):
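# Run a dummy prefill over `seqlen` tokens so TunableOp benchmarks (and records)
# the GEMM shapes used at that token count.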
input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
kv_cache = get_cache_manager().kv_cache
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
cu_seqlen_prefill=torch.tensor(
[0, seqlen], device=self.device, dtype=torch.int32
),
kv_cache=get_cache_manager().kv_cache,
block_tables=None,
input_lengths=None,
slots=slots,
max_s=seqlen,
lm_head_indices=None,
)
def forward(
self, batch: FlashCausalLMBatch
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -1113,8 +1184,6 @@ class FlashCausalLM(Model):
next_token_texts = []
left = 0
logger.debug(f"Accepted ids {n_accepted_ids}")
current_stopped = False
for j in range(index, index + n_accepted_ids):
# Generated token
@@ -15,11 +15,10 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM
class FlashGPT2(FlashCausalLM):
def __init__(
@@ -15,3 +15,12 @@ else:
cuda_graphs = None
CUDA_GRAPHS = cuda_graphs
# This is overridden at model loading.
global MODEL_ID
MODEL_ID = None
def set_model_id(model_id: str):
global MODEL_ID
MODEL_ID = model_id
@@ -21,6 +21,7 @@ from text_generation_server.models.vlm_causal_lm import (
from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
from text_generation_server.models.globals import set_model_id
class SignalHandler:
@@ -252,6 +253,7 @@ def serve(
while signal_handler.KEEP_PROCESSING:
await asyncio.sleep(0.5)
set_model_id(model_id)
asyncio.run(
serve_inner(
model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
@@ -2,14 +2,18 @@ import os
import torch
from loguru import logger
import math
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.flash_attn_triton import triton_attention
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
HAS_FLASH_ATTN = True
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
ROCM_USE_FLASH_ATTN_V2_CK = False
ROCM_USE_FLASH_ATTN_V2_TRITON = False
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
@@ -57,10 +61,21 @@ if SYSTEM in {"cuda", "rocm"}:
is_sm75 = major == 7 and minor == 5
is_sm8x = major == 8 and minor >= 0
is_sm90 = major == 9 and minor == 0
is_sm94 = major == 9 and minor == 4
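# On ROCm, gfx90a (MI210/MI250) reports capability 9.0, while gfx942 (MI300) reports 9.4.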
if SYSTEM == "rocm":
if (
os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() == "true"
or os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "0") == "1"
):
ROCM_USE_FLASH_ATTN_V2_TRITON = True
logger.info("ROCm: using Flash Attention 2 Triton implementation.")
else:
ROCM_USE_FLASH_ATTN_V2_CK = True
logger.info(
"ROCm: using Flash Attention 2 Composable Kernel implementation."
)
HAS_FLASH_ATTN = False
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
try:
try:
import flash_attn_2_cuda
@@ -71,11 +86,16 @@ if SYSTEM in {"cuda", "rocm"}:
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
)
if not (is_sm8x or is_sm90):
if SYSTEM == "cuda" and not (is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
elif SYSTEM == "rocm" and not (is_sm8x or is_sm90 or is_sm94):
raise ImportError(
f"AMD GPU with compute capability {major} {minor} is not supported for "
"Flash Attention V2"
)
HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
except ImportError as e:
@@ -142,7 +162,7 @@ if HAS_FLASH_ATTN_V2_CUDA:
None,
)
elif HAS_FLASH_ATTN_V2_ROCM:
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_CK:
def attention(
q,
@@ -153,6 +173,7 @@ elif HAS_FLASH_ATTN_V2_ROCM:
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@@ -174,11 +195,38 @@ elif HAS_FLASH_ATTN_V2_ROCM:
0.0,
softmax_scale,
False,
True,
causal,
False,
None,
)
elif HAS_FLASH_ATTN_V2_ROCM and ROCM_USE_FLASH_ATTN_V2_TRITON:
def attention(
q,
k,
v,
out,
cu_seqlens,
max_s,
softmax_scale,
window_size_left=-1,
causal=True,
):
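# The Triton kernel takes separate cu_seqlens/max_s for Q and K/V; during
# prefill both sides share the same layout, so the same tensors are passed twice.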
output, _ = triton_attention(
q,
k,
v,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
causal,
softmax_scale,
)
return output
elif HAS_FLASH_ATTN:
def attention(
This diff is collapsed.
@@ -5,6 +5,14 @@ _PARTITION_SIZE = 512
if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
else:
try:
from vllm._C import cache_ops
from vllm._C import ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
)
def reshape_and_cache(
@@ -14,22 +22,14 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if SYSTEM == "cuda":
from vllm._C import cache_ops
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
elif SYSTEM == "rocm":
from vllm import cache_ops
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
elif SYSTEM == "xpu":
if SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
else:
raise ValueError("vllm is not supported on your system")
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
def attention(
@@ -87,43 +87,21 @@ def attention(
# to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
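# V1 runs one kernel over the whole sequence; V2 splits long sequences into
# partitions and reduces the partial results (exp_sums/max_logits) in a second pass.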
if use_v1:
if SYSTEM == "cuda":
from vllm._C import ops
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
elif SYSTEM == "rocm":
from vllm import attention_ops
attention_ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
else:
raise ValueError("vllm is not supported on your system")
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
# Run PagedAttention V2.
assert _PARTITION_SIZE % block_size == 0
@@ -139,45 +117,21 @@ def attention(
)
max_logits = torch.empty_like(exp_sums)
if SYSTEM == "cuda":
from vllm._C import ops
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
elif SYSTEM == "rocm":
from vllm import attention_ops
attention_ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
)
else:
raise ValueError("vllm is not supported on your system")
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)