Unverified Commit b2b5df0e authored by fxmarty, committed by GitHub

Add RoCm support (#1243)



This PR adds support for AMD Instinct MI210 & MI250 GPUs, including paged
attention and Flash Attention v2 (FAv2).

Remaining items to discuss, among possibly others:
* Should we have a
`ghcr.io/huggingface/text-generation-inference:1.1.0+rocm` hosted image,
or is it too early?
* Should we set up CI on MI210/MI250? I don't have access to TGI's
runners, though.
* Are we comfortable with those changes being directly in TGI, or do we
need a fork?

---------
Co-authored-by: Felix Marty <felix@hf.co>
Co-authored-by: OlivierDehaene <olivier@huggingface.co>
Co-authored-by: Your Name <you@example.com>
parent ed2a3f61
server/text_generation_server/utils/import_utils.py (new file)

import torch
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
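For reference, a minimal sketch (illustrative only, not part of this commit) of how these flags resolve at runtime: a ROCm build of PyTorch populates torch.version.hip, a CUDA build populates torch.version.cuda, and the other attribute is None.

import torch

# On a ROCm wheel, torch.version.hip holds a version string and torch.version.cuda
# is None; on a CUDA wheel it is the opposite, so exactly one flag is True on GPU builds.
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None

if IS_ROCM_SYSTEM:
    print(f"ROCm build: {torch.version.hip}")
elif IS_CUDA_SYSTEM:
    print(f"CUDA build: {torch.version.cuda}")
else:
    print("CPU-only build")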
@@ -12,14 +12,13 @@ HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
     from bitsandbytes.nn import Int8Params, Params4bit
 except ImportError:
     HAS_BITS_AND_BYTES = False

 from accelerate import init_empty_weights

 from text_generation_server.utils.gptq.quant_linear import QuantLinear
+from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM

 HAS_AWQ = True
 try:
@@ -525,11 +524,14 @@ class TensorParallelEmbedding(nn.Module):
 try:
-    import dropout_layer_norm
+    if IS_CUDA_SYSTEM:
+        import dropout_layer_norm
+    else:
+        dropout_layer_norm = None

     class FastLayerNorm(nn.LayerNorm):
         def forward(self, hidden_states, residual=None):
-            if hidden_states.shape[-1] > 8192:
+            if hidden_states.shape[-1] > 8192 or IS_ROCM_SYSTEM:
                 if residual is not None:
                     hidden_states += residual
                 residual = hidden_states
@@ -561,14 +563,16 @@ try:
                 residual = hidden_states

             return normed_hidden_states, residual

 except ImportError:
     pass
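For context, a minimal sketch (an assumption about the elided body, not part of the diff) of the non-fused branch that FastLayerNorm.forward now takes on ROCm, and on CUDA for hidden sizes above 8192: the residual is folded in eagerly and the stock nn.LayerNorm kernel runs, instead of the fused dropout_layer_norm CUDA extension.

import torch
from torch import nn

class FallbackLayerNorm(nn.LayerNorm):
    # Sketch of the slow path: fold in the skip connection, then run the
    # standard LayerNorm kernel, returning the new residual alongside.
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states += residual
        residual = hidden_states
        return super().forward(hidden_states), residual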
 try:
-    from flash_attn.layers.rotary import RotaryEmbedding
-    import rotary_emb
+    if IS_CUDA_SYSTEM:
+        from flash_attn.layers.rotary import RotaryEmbedding
+        import rotary_emb
+    elif IS_ROCM_SYSTEM:
+        from vllm import pos_encoding_ops
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
@@ -597,6 +601,37 @@ try:
             self.scaling_factor = scaling_factor
             self.dynamic_args = None
+        def forward(self, query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
+            # Such control flow may add some overhead.
+            if IS_CUDA_SYSTEM:
+                rotary_dim = cos.shape[-1]
+                q1 = query[..., :rotary_dim]
+                q2 = query[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
+
+                k1 = key[..., :rotary_dim]
+                k2 = key[..., rotary_dim : 2 * rotary_dim]
+
+                rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
+            elif IS_ROCM_SYSTEM:
+                # NOTE: On RoCm systems, we use a RoPE implementation adapted from vLLM, which launches
+                # a single kernel for both query and key, contrary to the flash-attn implementation used
+                # on NVIDIA systems.
+                # When compiling flash-attn's rotary kernel on RoCm, hipcc appears unable to unroll the
+                # loops, resulting in inference even slower than eager mode:
+                # https://github.com/pytorch/pytorch/issues/113773
+                head_size = query.shape[-1]
+
+                # In-place operation, updating query and key.
+                pos_encoding_ops.rotary_embedding(
+                    query,
+                    key,
+                    head_size,
+                    cos,
+                    sin,
+                    True,
+                )
+            else:
+                raise ValueError(
+                    "Your system does not seem to be supported. Please check your installation or open "
+                    "an issue at https://github.com/huggingface/text-generation-inference/issues with a "
+                    "clear reproduction."
+                )
         @classmethod
         def static(cls, config, dim, base, device):
             inv_freq = _create_inv_freq(dim, base, device)
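For intuition, a minimal pure-PyTorch reference (an illustrative assumption, not either fused kernel) of the rotation that both the flash-attn path and the vLLM path apply to each (x1, x2) half-pair of the rotary dimensions:

import torch

def apply_rotary_ref(x1: torch.Tensor, x2: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> None:
    # Standard RoPE rotation, written in place to mirror the fused kernels.
    rotated1 = x1 * cos - x2 * sin
    rotated2 = x1 * sin + x2 * cos
    x1.copy_(rotated1)
    x2.copy_(rotated2)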
@@ -699,21 +734,19 @@ try:
""" """
Return cos and sin for the asked position ids Return cos and sin for the asked position ids
""" """
if IS_ROCM_SYSTEM:
# For RoCm, we always use float cos/sin to avoid a cast.
# For NVIDIA, for some reason, the flash-attn rotary kernel requires cos/sin and query/key to be of same dtype: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary.cpp#L26
# But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
dtype = torch.float32
self._update_cos_sin_cache(dtype, position_ids.device, max_s) self._update_cos_sin_cache(dtype, position_ids.device, max_s)
cos = torch.index_select(self._cos_cached, 0, position_ids) cos = torch.index_select(self._cos_cached, 0, position_ids)
sin = torch.index_select(self._sin_cached, 0, position_ids) sin = torch.index_select(self._sin_cached, 0, position_ids)
# Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
return cos.unsqueeze(1), sin.unsqueeze(1) return cos.unsqueeze(1), sin.unsqueeze(1)
def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
rotary_dim = cos.shape[-1]
x1 = x[..., :rotary_dim]
x2 = x[..., rotary_dim : 2 * rotary_dim]
rotary_emb.apply_rotary(x1, x2, cos, sin, x1, x2, False)
return x
     class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
         def __init__(self, dim, max_position_embeddings, base, device, scaling_factor):
             inv_freq = _create_inv_freq(dim, base, device)
@@ -722,7 +755,7 @@ try:
             self.max_position_embeddings = max_position_embeddings
             self.base = base

         def _update_cos_sin_cache(self, dtype, device, seqlen):
             # Reset the tables if the sequence length has changed,
             # or if we're on a new device (possibly due to tracing for instance)
             if (
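The hunk is truncated here; for context, a minimal sketch (an assumption reconstructed from the surrounding definitions, not the verbatim diff) of what such a cos/sin cache update typically does:

import torch

def _update_cos_sin_cache(self, dtype, device, seqlen):
    # Recompute the tables only when the cache is stale: a longer
    # sequence, a new device, or a new dtype.
    if (
        seqlen > self._seq_len_cached
        or self._cos_cached.device != device
        or self._cos_cached.dtype != dtype
    ):
        self._seq_len_cached = seqlen
        t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq.to(device=t.device))
        self._cos_cached = torch.cos(freqs).to(dtype)
        self._sin_cached = torch.sin(freqs).to(dtype)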