"...text-generation-inference.git" did not exist on "153ff3740bd32f43c4346b6d9c1708a4c46c793f"
Unverified commit fd89d9df, authored by Nicolas Patry and committed via GitHub.

Refactor layers. (#1866)

# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->
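Moves the layer implementations out of `text_generation_server.utils.layers` and into a dedicated `text_generation_server.layers` package (with submodules such as `layers.layernorm`, `layers.gptq`, and `layers.awq.conversion_utils`), and collapses the `IS_CUDA_SYSTEM` / `IS_ROCM_SYSTEM` / `IS_XPU_SYSTEM` flags in `utils.import_utils` into a single `SYSTEM` string plus `empty_cache`, `synchronize`, and `get_free_memory` helpers. No behavior change is intended. As a rough sketch of what call sites look like after the move (import paths taken from the diff below; the grouping is illustrative, not exhaustive):

```python
# Old layout (removed by this PR):
#   from text_generation_server.utils.layers import TensorParallelColumnLinear, FastRMSNorm, HAS_EXLLAMA
#   from text_generation_server.utils.import_utils import IS_CUDA_SYSTEM, IS_ROCM_SYSTEM, IS_XPU_SYSTEM

# New layout (paths as they appear in the hunks below):
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    SpeculativeHead,
    FastLinear,
)
from text_generation_server.layers.layernorm import FastRMSNorm
from text_generation_server.layers.gptq import HAS_EXLLAMA, create_exllama_buffers, set_device
from text_generation_server.layers.awq.conversion_utils import fast_awq_to_gptq
from text_generation_server.utils.import_utils import SYSTEM, empty_cache, synchronize, get_free_memory
```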

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
parent 59b3ffea
@@ -8,12 +8,12 @@ from typing import Optional, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
 import torch.nn.functional as F
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     SpeculativeHead,
     TensorParallelEmbedding,
-    FastRMSNorm,
     FastLinear,
 )
+from text_generation_server.layers.layernorm import FastRMSNorm
 from einops import rearrange
 from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
...
@@ -17,7 +17,7 @@ from transformers.modeling_outputs import (
 )
 from einops import rearrange
 from packaging import version
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelEmbedding,
     TensorParallelColumnLinear,
     TensorParallelRowLinear,
...
@@ -40,7 +40,7 @@ from transformers.modeling_outputs import (
 from transformers.modeling_utils import PreTrainedModel
 from transformers import GPTNeoXConfig
 from loguru import logger
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
...
@@ -27,7 +27,7 @@ from transformers.modeling_outputs import (
 )
 from transformers.modeling_utils import PreTrainedModel
 from transformers import OPTConfig
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     FastLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
...
@@ -9,7 +9,7 @@ from typing import Optional, List, Tuple, Any
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
...
@@ -38,7 +38,7 @@ from transformers.utils import (
     is_torch_fx_proxy,
 )
 from transformers import T5Config
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
     TensorParallelRowLinear,
...
@@ -12,7 +12,6 @@ from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
@@ -32,13 +31,14 @@ from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION

-tracer = trace.get_tracer(__name__)
-
 from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
+    empty_cache,
+    synchronize,
+    get_free_memory,
 )
+
+tracer = trace.get_tracer(__name__)
+

 @dataclass
 class FlashCausalLMBatch(Batch):
@@ -757,10 +757,8 @@ class FlashCausalLM(Model):
     def warmup(self, batch: FlashCausalLMBatch):
         # The warmup batch is the biggest batch we could ever receive
-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.empty_cache()
-        elif IS_XPU_SYSTEM:
-            torch.xpu.empty_cache()
+        empty_cache()

         try:
             cache_manager = set_cache_manager(
                 batch.blocks,
@@ -780,10 +778,7 @@ class FlashCausalLM(Model):
                 f"You need to decrease `--max-batch-prefill-tokens`"
             ) from e

-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            torch.cuda.synchronize(self.device)
-        elif IS_XPU_SYSTEM:
-            torch.xpu.synchronize(self.device)
+        synchronize(self.device)

         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
@@ -791,20 +786,7 @@ class FlashCausalLM(Model):
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

-        if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
-            total_free_memory, _ = torch.cuda.mem_get_info(self.device)
-            total_gpu_memory = torch.cuda.get_device_properties(
-                self.device
-            ).total_memory
-            free_memory = max(
-                0, total_free_memory - (1 - MEMORY_FRACTION) * total_gpu_memory
-            )
-        elif IS_XPU_SYSTEM:
-            total_gpu_memory = torch.xpu.get_device_properties(self.device).total_memory
-            free_memory = int(total_gpu_memory * 0.5)
-        else:
-            raise NotImplementedError("FlashModel is only available on GPU")
+        free_memory = get_free_memory(self.device, MEMORY_FRACTION)

         num_blocks = (
             # Leave 5% for some wiggle room
...
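The warmup hunks above only swap the per-backend branches for `empty_cache()`, `synchronize(self.device)`, and `get_free_memory(self.device, MEMORY_FRACTION)`; the KV-cache block arithmetic around them is unchanged. A standalone sketch of that arithmetic with assumed example values (`block_size = 16` and the exact "leave 5%" expression are assumptions; the other names follow the hunk):

```python
def estimate_kv_cache_blocks(
    free_memory: int,      # bytes, as reported by get_free_memory(device, MEMORY_FRACTION)
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    dtype_size: int,       # bytes per element, e.g. 2 for float16
    block_size: int = 16,  # stands in for BLOCK_SIZE in the surrounding file (assumed value)
) -> int:
    # One cache block holds `block_size` token slots for every KV head.
    cache_block_size = block_size * num_kv_heads * head_size
    # Keys and values ("* 2") across every layer, in bytes per block.
    total_cache_size = num_layers * cache_block_size * 2 * dtype_size
    # Leave roughly 5% of the free memory as wiggle room, as the hunk's comment says.
    return int(free_memory * 0.95 // total_cache_size)


# Example: a 32-layer model with 8 KV heads of size 128, fp16, and 24 GiB free.
print(estimate_kv_cache_blocks(24 * 1024**3, 32, 8, 128, 2))  # -> 11673 blocks
```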
@@ -18,7 +18,7 @@ from text_generation_server.utils import (

 tracer = trace.get_tracer(__name__)

-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM


 class FlashLlama(FlashCausalLM):
@@ -35,7 +35,7 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
...
@@ -33,7 +33,7 @@ tracer = trace.get_tracer(__name__)
 # Will be set in init
 SLIDING_WINDOW: Optional[int] = None
 SLIDING_WINDOW_BLOCKS: Optional[int] = None
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
@@ -322,7 +322,7 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
...
@@ -14,7 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -33,7 +33,7 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
...
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -34,7 +34,7 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
...
@@ -18,7 +18,7 @@ from text_generation_server.utils import (
     Weights,
 )
-from text_generation_server.utils.import_utils import IS_XPU_SYSTEM
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -37,7 +37,7 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif IS_XPU_SYSTEM:
+        elif SYSTEM == "xpu":
             device = torch.device(f"xpu:{rank}")
             dtype = torch.float16 if dtype is None else dtype
         else:
...
@@ -85,7 +85,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         # When using GPTQ, Exllama kernels need some global kernels
         # For which we have the finale shapes only after the model has loaded
         # This will allocate those buffers.
-        from text_generation_server.utils.layers import (
+        from text_generation_server.layers.gptq import (
             create_exllama_buffers,
             set_device,
         )
...
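The Exllama buffer helpers imported above now live in `text_generation_server.layers.gptq`. A hedged sketch of how they are typically used once the model has loaded (the argument to `create_exllama_buffers` is an assumption for illustration; the surrounding file makes the real call):

```python
from text_generation_server.layers.gptq import create_exllama_buffers, set_device


def warm_up_exllama(model, max_total_tokens: int) -> None:
    """Allocate the global Exllama buffers once the final weight shapes are known.

    The value passed to create_exllama_buffers here is an assumed example;
    the real call happens in the file shown above after model loading.
    """
    set_device(model.device)
    create_exllama_buffers(max_total_tokens)
```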
@@ -2,13 +2,8 @@ import os
 import torch

 from loguru import logger
-import math

-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
     raise ImportError("`USE_FLASH_ATTENTION` is false.")
@@ -16,10 +11,45 @@ HAS_FLASH_ATTN = True
 HAS_FLASH_ATTN_V2_CUDA = False
 HAS_FLASH_ATTN_V2_ROCM = False

-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex

-if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
+
+        if window_size_left != -1:
+            raise ValueError(
+                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
+            )
+        return ipex.llm.functional.varlen_attention(
+            q,
+            k,
+            v,
+            out,
+            cu_seqlens,
+            cu_seqlens,
+            max_s,
+            max_s,
+            0.0,
+            softmax_scale,
+            False,
+            True,
+            False,
+            None,
+        )
+
+
+if SYSTEM in {"cuda", "rocm"}:
     if not torch.cuda.is_available():
         raise ImportError("CUDA is not available")
@@ -35,11 +65,7 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
         try:
             import flash_attn_2_cuda
         except ImportError:
-            architecture_suffix = ""
-            if IS_CUDA_SYSTEM:
-                architecture_suffix = "-cuda"
-            elif IS_ROCM_SYSTEM:
-                architecture_suffix = "-rocm"
+            architecture_suffix = f"-{SYSTEM}"
             raise ImportError(
                 "Flash Attention V2 is not installed.\n"
                 "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
@@ -50,8 +76,8 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
                 f"GPU with CUDA capability {major} {minor} is not supported for "
                 "Flash Attention V2"
             )
-        HAS_FLASH_ATTN_V2_CUDA = IS_CUDA_SYSTEM
-        HAS_FLASH_ATTN_V2_ROCM = IS_ROCM_SYSTEM
+        HAS_FLASH_ATTN_V2_CUDA = SYSTEM == "cuda"
+        HAS_FLASH_ATTN_V2_ROCM = SYSTEM == "rocm"
     except ImportError as e:
         try:
             import flash_attn_cuda
@@ -62,11 +88,11 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
                 "or install flash attention with `cd server && make install install-flash-attention`"
             ) from e

-        if IS_CUDA_SYSTEM and not (is_sm75 or is_sm8x or is_sm90):
+        if SYSTEM == "cuda" and not (is_sm75 or is_sm8x or is_sm90):
             raise ImportError(
                 f"GPU with CUDA capability {major} {minor} is not supported"
             ) from e
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             for idx in range(torch.cuda.device_count()):
                 if "MI210" not in torch.cuda.get_device_name(
                     idx
@@ -79,42 +105,20 @@ if IS_CUDA_SYSTEM or IS_ROCM_SYSTEM:
         HAS_FLASH_ATTN = True


-def attention(
-    q,
-    k,
-    v,
-    out,
-    cu_seqlens,
-    max_s,
-    softmax_scale,
-    window_size_left=-1,
-):
-    if window_size_left <= 0 and window_size_left != -1:
-        raise ValueError("`window_size_left` must be > 0 or -1")
-
-    if IS_XPU_SYSTEM:
-        if window_size_left != -1:
-            raise ValueError(
-                f"XPU version of Flash Attention does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
-            )
-        return ipex.llm.functional.varlen_attention(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            None,
-        )
-
-    if HAS_FLASH_ATTN_V2_CUDA:
+if HAS_FLASH_ATTN_V2_CUDA:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
         return flash_attn_2_cuda.varlen_fwd(
             q,
             k,
@@ -136,7 +140,21 @@ def attention(
             False,
             None,
         )
-    elif HAS_FLASH_ATTN_V2_ROCM:
+
+elif HAS_FLASH_ATTN_V2_ROCM:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
+        if window_size_left <= 0 and window_size_left != -1:
+            raise ValueError("`window_size_left` must be > 0 or -1")
         if window_size_left != -1:
             raise ValueError(
                 f"RoCm version of Flash Attention v2 does not support window attention (window_size_left != -1, got window_size_left={window_size_left})."
@@ -159,7 +177,19 @@ def attention(
             False,
             None,
         )
-    elif HAS_FLASH_ATTN:
+
+elif HAS_FLASH_ATTN:
+
+    def attention(
+        q,
+        k,
+        v,
+        out,
+        cu_seqlens,
+        max_s,
+        softmax_scale,
+        window_size_left=-1,
+    ):
         if window_size_left != -1:
             raise NotImplementedError(
                 "window_size_left is only available with flash attn v2"
@@ -209,4 +239,5 @@ def attention(
             None,
         )

+else:
     raise NotImplementedError("flash attention is not installed")
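After this change each backend defines the same module-level `attention` signature and the selection happens once at import time, instead of branching on `IS_*` flags inside one function body. A minimal call-site sketch (the module path and tensor shapes are assumptions; the signature is the one defined above):

```python
import torch

# Path assumed unchanged by this PR; only the body of the module is refactored here.
from text_generation_server.utils import flash_attn

total_tokens, num_heads, head_size = 7, 8, 64  # two packed sequences of 3 and 4 tokens
q = torch.randn(total_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch.empty_like(q)
cu_seqlens = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda")  # cumulative lengths
max_s = 4  # longest sequence in the batch

flash_attn.attention(
    q, k, v, out,
    cu_seqlens,
    max_s,
    softmax_scale=head_size ** -0.5,
    window_size_left=-1,  # sliding windows are only supported by flash-attn v2 on CUDA
)
```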
@@ -10,6 +10,41 @@ def is_xpu_available():
     return hasattr(torch, "xpu") and torch.xpu.is_available()


-IS_ROCM_SYSTEM = torch.version.hip is not None
-IS_CUDA_SYSTEM = torch.version.cuda is not None
-IS_XPU_SYSTEM = is_xpu_available()
+def get_cuda_free_memory(device, memory_fraction):
+    total_free_memory, _ = torch.cuda.mem_get_info(device)
+    total_gpu_memory = torch.cuda.get_device_properties(device).total_memory
+    free_memory = max(0, total_free_memory - (1 - memory_fraction) * total_gpu_memory)
+    return free_memory
+
+
+def get_xpu_free_memory(device, memory_fraction):
+    total_gpu_memory = torch.xpu.get_device_properties(device).total_memory
+    free_memory = int(total_gpu_memory * 0.5)
+    return free_memory
+
+
+SYSTEM = None
+if torch.version.hip is not None:
+    SYSTEM = "rocm"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif torch.version.cuda is not None and torch.cuda.is_available():
+    SYSTEM = "cuda"
+    empty_cache = torch.cuda.empty_cache
+    synchronize = torch.cuda.synchronize
+    get_free_memory = get_cuda_free_memory
+elif is_xpu_available():
+    SYSTEM = "xpu"
+    empty_cache = torch.xpu.empty_cache
+    synchronize = torch.xpu.synchronize
+    get_free_memory = get_xpu_free_memory
+else:
+    SYSTEM = "cpu"
+
+    def noop(*args, **kwargs):
+        pass
+
+    empty_cache = noop
+    synchronize = noop
+    get_free_memory = noop
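Call sites now read the single `SYSTEM` string and the three helpers instead of the old `IS_*` booleans, as the `FlashCausalLM` warmup hunks above show. A small usage sketch (the device index and the 0.9 memory fraction are example values):

```python
import torch

from text_generation_server.utils.import_utils import (
    SYSTEM,
    empty_cache,
    synchronize,
    get_free_memory,
)

# Pick a device the same way the model loaders above do.
if SYSTEM in {"cuda", "rocm"}:
    device = torch.device("cuda:0")
elif SYSTEM == "xpu":
    device = torch.device("xpu:0")
else:
    device = torch.device("cpu")

empty_cache()                              # torch.cuda / torch.xpu empty_cache, or a no-op on CPU
synchronize(device)                        # same per-backend dispatch
free_bytes = get_free_memory(device, 0.9)  # 0.9 plays the role of MEMORY_FRACTION here
print(SYSTEM, free_bytes)                  # on CPU the no-op helper returns None
```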
This diff is collapsed.
 import torch

-from text_generation_server.utils.import_utils import (
-    IS_CUDA_SYSTEM,
-    IS_ROCM_SYSTEM,
-    IS_XPU_SYSTEM,
-)
+from text_generation_server.utils.import_utils import SYSTEM

 _PARTITION_SIZE = 512

-if IS_XPU_SYSTEM:
+if SYSTEM == "xpu":
     import intel_extension_for_pytorch as ipex
@@ -18,17 +14,17 @@ def reshape_and_cache(
     value_cache: torch.Tensor,
     slots: torch.Tensor,
 ):
-    if IS_CUDA_SYSTEM:
+    if SYSTEM == "cuda":
         from vllm._C import cache_ops

         cache_ops.reshape_and_cache(
             key, value, key_cache, value_cache, slots, "auto", 1.0
         )
-    elif IS_ROCM_SYSTEM:
+    elif SYSTEM == "rocm":
         from vllm import cache_ops

         cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slots)
-    elif IS_XPU_SYSTEM:
+    elif SYSTEM == "xpu":
         ipex.llm.modules.PagedAttention.reshape_and_cache(
             key, value, key_cache, value_cache, slots
         )
@@ -68,7 +64,7 @@ def attention(
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
-    if IS_XPU_SYSTEM:
+    if SYSTEM == "xpu":
         query = query.contiguous()
         return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
             out,
@@ -91,7 +87,7 @@ def attention(
     # to parallelize.
     use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             from vllm._C import ops

             ops.paged_attention_v1(
@@ -109,7 +105,7 @@ def attention(
                 "auto",
                 1.0,
             )
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             from vllm import attention_ops

             attention_ops.paged_attention_v1(
@@ -143,7 +139,7 @@ def attention(
         )
         max_logits = torch.empty_like(exp_sums)

-        if IS_CUDA_SYSTEM:
+        if SYSTEM == "cuda":
             from vllm._C import ops

             ops.paged_attention_v2(
@@ -164,7 +160,7 @@ def attention(
                 "auto",
                 1.0,
             )
-        elif IS_ROCM_SYSTEM:
+        elif SYSTEM == "rocm":
             from vllm import attention_ops

             attention_ops.paged_attention_v2(
...
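The hunks above only rename the backend checks; the v1/v2 kernel-selection heuristic they surround is unchanged. Restated as a standalone helper (the constant and the condition are copied from the context lines above; the parallelism comment paraphrases the one in the file):

```python
_PARTITION_SIZE = 512  # same partition size as in the file above


def use_paged_attention_v1(max_s: int, num_seqs: int, num_heads: int) -> bool:
    """Decide between paged_attention_v1 and v2, mirroring the context lines above."""
    # Ceil-divide the longest sequence into 512-token partitions.
    max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
    # v1 wins when there is nothing to partition, or when the batch already has
    # enough (sequence, head) pairs to keep the GPU busy without splitting work.
    return max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)


print(use_paged_attention_v1(max_s=512, num_seqs=1, num_heads=32))   # True: a single partition
print(use_paged_attention_v1(max_s=4096, num_seqs=4, num_heads=32))  # False: 8 partitions, only 128 head-seq pairs
```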
@@ -171,7 +171,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.utils.awq.conversion_utils import (
+                from text_generation_server.layers.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
@@ -227,7 +227,7 @@ class Weights:
             bits, groupsize, desc_act, quant_method = self._get_gptq_params()
-            from text_generation_server.utils.layers import HAS_EXLLAMA
+            from text_generation_server.layers.gptq import HAS_EXLLAMA

             use_exllama = (
                 bits == 4 and HAS_EXLLAMA and quantize == "gptq" and not desc_act
@@ -242,7 +242,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.utils.awq.conversion_utils import (
+                from text_generation_server.layers.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
@@ -321,7 +321,7 @@ class Weights:
             # it would require to reorder input activations that are split unto several GPUs
             use_exllama = False

-            from text_generation_server.utils.layers import HAS_EXLLAMA, CAN_EXLLAMA
+            from text_generation_server.layers.gptq import HAS_EXLLAMA, CAN_EXLLAMA

             if use_exllama:
                 if not HAS_EXLLAMA:
@@ -348,7 +348,7 @@ class Weights:
                 log_once(
                     logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
                 )
-                from text_generation_server.utils.awq.conversion_utils import (
+                from text_generation_server.layers.awq.conversion_utils import (
                     fast_awq_to_gptq,
                 )
...