Unverified Commit fd89d9df authored by Nicolas Patry, committed by GitHub

Refactor layers. (#1866)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
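This PR moves the layer primitives out of `text_generation_server.utils.layers` and into a dedicated `text_generation_server.layers` package (with submodules such as `layers.layernorm`, `layers.gptq` and `layers.awq`), and it consolidates the per-device flags `IS_CUDA_SYSTEM` / `IS_ROCM_SYSTEM` / `IS_XPU_SYSTEM` into a single `SYSTEM` string plus `empty_cache` / `synchronize` / `get_free_memory` helpers in `utils.import_utils`. The sketch below only illustrates the new import style used throughout the diff; it is not a verbatim excerpt from any one file:

```python
# New-style imports after this refactor (illustrative, not a verbatim excerpt).
from text_generation_server.layers import TensorParallelEmbedding, SpeculativeHead
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.utils.import_utils import SYSTEM

# Device selection now branches on the SYSTEM string instead of IS_*_SYSTEM flags.
rank = 0
device = f"xpu:{rank}" if SYSTEM == "xpu" else f"cuda:{rank}"
```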

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
parent 59b3ffea
......@@ -8,12 +8,12 @@ from typing import Optional, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
import torch.nn.functional as F
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
SpeculativeHead,
TensorParallelEmbedding,
-FastRMSNorm,
FastLinear,
)
+from text_generation_server.layers.layernorm import FastRMSNorm
from einops import rearrange
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
......
......@@ -17,7 +17,7 @@ from transformers.modeling_outputs import (
)
from einops import rearrange
from packaging import version
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
TensorParallelEmbedding,
TensorParallelColumnLinear,
TensorParallelRowLinear,
......
......@@ -40,7 +40,7 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers import GPTNeoXConfig
from loguru import logger
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
......
......@@ -27,7 +27,7 @@ from transformers.modeling_outputs import (
)
from transformers.modeling_utils import PreTrainedModel
from transformers import OPTConfig
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
FastLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
......
......@@ -9,7 +9,7 @@ from typing import Optional, List, Tuple, Any
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
......
......@@ -38,7 +38,7 @@ from transformers.utils import (
is_torch_fx_proxy,
)
from transformers import T5Config
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
TensorParallelRowLinear,
......
......@@ -12,7 +12,6 @@ from dataclasses import dataclass
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.speculate import get_speculate
......@@ -32,13 +31,14 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION
-tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import (
IS_CUDA_SYSTEM,
IS_ROCM_SYSTEM,
IS_XPU_SYSTEM,
+empty_cache,
+synchronize,
+get_free_memory,
)
+tracer = trace.get_tracer(__name__)
@dataclass
class FlashCausalLMBatch(Batch):
......@@ -757,10 +757,8 @@ class FlashCausalLM(Model):
def warmup(self, batch: FlashCausalLMBatch):
# The warmup batch is the biggest batch we could ever receive
-if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-    torch.cuda.empty_cache()
-elif IS_XPU_SYSTEM:
-    torch.xpu.empty_cache()
+empty_cache()
try:
cache_manager = set_cache_manager(
batch.blocks,
......@@ -780,10 +778,7 @@ class FlashCausalLM(Model):
f"You need to decrease `--max-batch-prefill-tokens`"
) from e
-if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-    torch.cuda.synchronize(self.device)
-elif IS_XPU_SYSTEM:
-    torch.xpu.synchronize(self.device)
+synchronize(self.device)
# Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
# Calculate the number of blocks that can be allocated with the free memory
......@@ -791,20 +786,7 @@ class FlashCausalLM(Model):
cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
-if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-    total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-    total_gpu_memory = torch.cuda.get_device_properties(
-        self.device
-    ).total_memory
-    free_memory = max(
-        0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-    )
-elif IS_XPU_SYSTEM:
-    total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
-    free_memory = int(total_gpu_memory * 0.5)
-else:
-    raise NotImplementedError("FlashModel is only available on GPU")
+free_memory = get_free_memory(self.device, MEMORY_FRACTION)
num_blocks = (
# Leave 5% for some wiggle room
......
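The free-memory/block arithmetic above can be made concrete with a small worked example; every model dimension below is hypothetical and only the formula mirrors the diff (`num_layers * cache_block_size * 2 * dtype_size` bytes per block, with 5% head-room):

```python
# Worked example of the KV-cache block budget computed in warmup() above.
# All sizes are made up; only the formula mirrors the diff.
BLOCK_SIZE = 16      # tokens per cache block
num_kv_heads = 8
head_size = 128
num_layers = 32
dtype_size = 2       # bytes per element for float16

cache_block_size = BLOCK_SIZE * num_kv_heads * head_size            # 16384 elements
total_cache_size = num_layers * cache_block_size * 2 * dtype_size   # 2 MiB per block (K and V)

free_memory = 40 * 1024**3          # pretend get_free_memory() reported 40 GiB
# Leave 5% for some wiggle room, as the comment in the diff says
num_blocks = int(free_memory * 0.95) // total_cache_size             # about 19,500 blocks here
```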
......@@ -18,7 +18,7 @@ from text_generation_server.utils import (
tracer = trace.get_tracer(__name__)
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
class FlashLlama(FlashCausalLM):
......@@ -35,7 +35,7 @@ class FlashLlama(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
-elif IS_XPU_SYSTEM:
+elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
......
......@@ -33,7 +33,7 @@ tracer = trace.get_tracer(__name__)
# Will be set in init
SLIDING_WINDOW: Optional[int] = None
SLIDING_WINDOW_BLOCKS: Optional[int] = None
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
......@@ -322,7 +322,7 @@ class BaseFlashMistral(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
-elif IS_XPU_SYSTEM:
+elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
......
......@@ -14,7 +14,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
......@@ -33,7 +33,7 @@ class FlashNeoXSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
-elif IS_XPU_SYSTEM:
+elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
......
......@@ -15,7 +15,7 @@ from text_generation_server.utils import (
weight_files,
Weights,
)
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
......@@ -34,7 +34,7 @@ class FlashRWSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
-elif IS_XPU_SYSTEM:
+elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
......
......@@ -18,7 +18,7 @@ from text_generation_server.utils import (
Weights,
)
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__)
......@@ -37,7 +37,7 @@ class FlashSantacoderSharded(FlashCausalLM):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype
-elif IS_XPU_SYSTEM:
+elif SYSTEM == "xpu":
device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype
else:
......
......@@ -85,7 +85,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
# When using GPTQ, Exllama kernels need some global kernels
# For which we have the finale shapes only after the model has loaded
# This will allocate those buffers.
-from text_generation_server.utils.layers import (
+from text_generation_server.layers.gptq import (
create_exllama_buffers,
set_device,
)
......
......@@ -2,13 +2,8 @@ import os
import torch
from loguru import logger
import math
-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
raise ImportError("`USE_FLASH_ATTENTION` is false.")
......@@ -16,10 +11,45 @@ HAS_FLASH_ATTN = True
HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False
-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
-if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
+        return ipex.llm.functional.varlen_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+if SYSTEM in {"cuda", "rocm"}:
if not torch.cuda.is_available():
raise ImportError("CUDA is not available")
......@@ -35,11 +65,7 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
try:
import flash_attn_2_cuda
except ImportError:
-architecture_suffix = ""
-if IS_CUDA_SYSTEM:
-    architecture_suffix = "-cuda"
-elif IS_ROCM_SYSTEM:
-    architecture_suffix = "-rocm"
+architecture_suffix = f"-{SYSTEM}"
raise ImportError(
"Flash Attention V2 is not installed.\n"
"Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
......@@ -50,8 +76,8 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
f"GPU with CUDA capability {major} {minor} is not supported for "
"Flash Attention V2"
)
-HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
+HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
+HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
except ImportError as e:
try:
import flash_attn_cuda
......@@ -62,11 +88,11 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
"or install flash attention with `cd server && make install install-flash-attention`"
) from e
-if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
+if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
raise ImportError(
f"GPU with CUDA capability {major} {minor} is not supported"
) from e
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
for idx in range(torch.cuda.device_count()):
if "MI210" not in torch.cuda.get_device_name(
idx
......@@ -79,42 +105,20 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
HAS_FLASH_ATTN = True
-def attention(
-    q,
-    k,
-    v,
-    out,
-    cu_seqlens,
-    max_s,
-    softmax_scale,
-    window_size_left=-1,
-):
-    if window_size_left <= 0 and window_size_left != -1:
-        raise ValueError("`window_size_left` must be > 0 or -1")
-    if IS_XPU_SYSTEM:
-        if window_size_left != -1:
-            raise ValueError(
-                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
-        return ipex.llm.functional.varlen_attention(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            None,
-        )
-    if HAS_FLASH_ATTN_V2_CUDA:
+if HAS_FLASH_ATTN_V2_CUDA:
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
return flash_attn_2_cuda.varlen_fwd(
q,
k,
......@@ -136,7 +140,21 @@ def attention(
False,
None,
)
-    elif HAS_FLASH_ATTN_V2_ROCM:
+elif HAS_FLASH_ATTN_V2_ROCM:
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
if window_size_left != -1:
raise ValueError(
f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
......@@ -159,7 +177,19 @@ def attention(
False,
None,
)
-    elif HAS_FLASH_ATTN:
+elif HAS_FLASH_ATTN:
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
if window_size_left != -1:
raise NotImplementedError(
"window_size_left is only available with flash attn v2"
......@@ -209,4 +239,5 @@ def attention(
None,
)
else:
raise NotImplementedError("flash attention is not installed")
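After this refactor the module binds exactly one `attention` implementation at import time: each backend branch (`SYSTEM == "xpu"`, `HAS_FLASH_ATTN_V2_CUDA`, `HAS_FLASH_ATTN_V2_ROCM`, `HAS_FLASH_ATTN`) defines its own `attention`, so call sites keep importing a single symbol. A minimal, self-contained sketch of that conditional-definition pattern follows; the backend probe and the kernel call are stand-ins, not the real checks or kernels:

```python
# Minimal sketch of the "define attention() per detected backend" pattern above.
# The probe and the kernel call are stand-ins for illustration only.
import torch

HAS_CUDA = torch.cuda.is_available()  # stand-in for the SYSTEM / HAS_FLASH_ATTN_* checks

if HAS_CUDA:

    def attention(q, k, v):
        # stand-in for flash_attn_2_cuda.varlen_fwd / flash_attn_cuda.fwd
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

else:

    def attention(q, k, v):
        # the real module raises at import time instead of defining a fallback
        raise NotImplementedError("flash attention is not installed")
```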
......@@ -10,6 +10,41 @@ def is_xpu_available():
return hasattr(torch, "xpu") and torch.xpu.is_available()
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_XPU_SYSTEM = is_xpu_available()
+def get_cuda_free_memory(device, memory_fraction):
+    total_free_memory, _ = torch.cuda.mem_get_info(device)
+    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
+    free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
+    return free_memory
+def get_xpu_free_memory(device):
+    total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
+    free_memory = int(total_gpu_memory * 0.5)
+    return free_memory
+SYSTEM = None
+if torch.version.hip is not None:
+    SYSTEM = "rocm"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif torch.version.cuda is not None and torch.cuda.is_available():
+    SYSTEM = "cuda"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif is_xpu_available():
+    SYSTEM = "xpu"
+    empty_cache = torch.xpu.empty_cache
+    synchronize = torch.xpu.synchronize
+    get_free_memory = get_xpu_free_memory
+else:
+    SYSTEM = "cpu"
+    def noop(*args, **kwargs):
+        pass
+    empty_cache = noop
+    synchronize = noop
+    get_free_memory = noop
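For reference, call sites consume these helpers roughly as in the sketch below (compare the `warmup` changes earlier in this diff). The snippet assumes the `text_generation_server` package is importable and uses an illustrative memory fraction rather than the real `MEMORY_FRACTION` from `utils.dist`:

```python
# Sketch of how the new helpers are consumed (mirrors the warmup() changes above).
from text_generation_server.utils.import_utils import (
    empty_cache,
    synchronize,
    get_free_memory,
)

MEMORY_FRACTION = 0.9  # illustrative; the real value comes from utils.dist


def measure_free_memory(device):
    empty_cache()
    synchronize(device)
    # CUDA/ROCm subtract the reserved fraction, XPU reports half of total
    # memory, and on CPU these helpers are no-ops.
    return get_free_memory(device, MEMORY_FRACTION)
```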
This diff is collapsed.
import torch
-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM
_PARTITION_SIZE = 512
-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
import intel_extension_for_pytorch as ipex
......@@ -18,17 +14,17 @@ def reshape_and_cache(
value_cache: torch.Tensor,
slots: torch.Tensor,
):
-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
from vllm._C import cache_ops
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
from vllm import cache_ops
cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-elif IS_XPU_SYSTEM:
+elif SYSTEM == "xpu":
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slots
)
......@@ -68,7 +64,7 @@ def attention(
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape
max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
query = query.contiguous()
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
out,
......@@ -91,7 +87,7 @@ def attention(
# to parallelize.
use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
if use_v1:
-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
from vllm._C import ops
ops.paged_attention_v1(
......@@ -109,7 +105,7 @@ def attention(
"auto",
1.0,
)
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
from vllm import attention_ops
attention_ops.paged_attention_v1(
......@@ -143,7 +139,7 @@ def attention(
)
max_logits = torch.empty_like(exp_sums)
-if IS_CUDA_SYSTEM:
+if SYSTEM == "cuda":
from vllm._C import ops
ops.paged_attention_v2(
......@@ -164,7 +160,7 @@ def attention(
"auto",
1.0,
)
-elif IS_ROCM_SYSTEM:
+elif SYSTEM == "rocm":
from vllm import attention_ops
attention_ops.paged_attention_v2(
......
......@@ -171,7 +171,7 @@ class Weights:
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
-from text_generation_server.utils.awq.conversion_utils import (
+from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
......@@ -227,7 +227,7 @@ class Weights:
bits, groupsize, desc_act, quant_method = self._get_gptq_params()
-from text_generation_server.utils.layers import HAS_EXLLAMA
+from text_generation_server.layers.gptq import HAS_EXLLAMA
use_exllama = (
bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
......@@ -242,7 +242,7 @@ class Weights:
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
-from text_generation_server.utils.awq.conversion_utils import (
+from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
......@@ -321,7 +321,7 @@ class Weights:
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
-from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
+from text_generation_server.layers.gptq import HAS_EXLLAMA, CAN_EXLLAMA
if use_exllama:
if not HAS_EXLLAMA:
......@@ -348,7 +348,7 @@ class Weights:
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
-from text_generation_server.utils.awq.conversion_utils import (
+from text_generation_server.layers.awq.conversion_utils import (
fast_awq_to_gptq,
)
......