Unverified Commit 9e2fdf57 authored by Nicolas Patry, committed by GitHub

Removing IPEX_AVAIL. (#2115)

* Removing IPEX_AVAIL.

CPU and XPU are now unified under `ipex`: most of the code is identical between the two, with only a few exceptions (see the device-selection sketch below).

Most of the remaining special cases are the kv-cache layout and the flash_xxx.py model files. Since those files should soon be factored away and removed, the duplication there is acceptable for now.

* Forgot a few places.

* Unrelated change.

* Fixing HF_TOKEN.

* HF_TOKEN
parent 3f3b7ffd
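
For readers skimming the diff, here is a minimal sketch of the device/dtype selection that this change applies across the flash_* model classes below. The helper name `pick_device_and_dtype` is invented for illustration; in the repository this logic lives inline in each model's `__init__`, and `SYSTEM` comes from `text_generation_server.utils.import_utils`.

```python
import torch

# SYSTEM normally comes from text_generation_server.utils.import_utils;
# it is hard-coded here only to keep the sketch self-contained.
SYSTEM = "ipex"


def pick_device_and_dtype(rank: int, dtype=None):
    # Mirrors the branch structure used in the flash_* models in this diff.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16 if dtype is None else dtype
    elif SYSTEM == "ipex":
        # A single branch now covers both Intel GPUs (XPU) and CPU with IPEX.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
        else:
            device = torch.device("cpu")
        dtype = torch.float16 if dtype is None else dtype
    else:
        raise NotImplementedError("Flash models are only available on GPU")
    return device, dtype
```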
@@ -178,6 +178,6 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
           export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
-          export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           echo $DOCKER_IMAGE
           pytest -s -vv integration-tests
@@ -22,5 +22,5 @@ jobs:
       - name: Run tests
         run: |
           pip install pytest pytest-asyncio
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           make python-client-tests
@@ -37,5 +37,5 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ inputs.docker_image }}
           export DOCKER_DEVICES=${{ inputs.docker_devices }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv integration-tests
@@ -28,7 +28,7 @@ jobs:
       - name: Start starcoder
         run: |
-          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
+          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
           sleep 10
           wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
...
@@ -72,7 +72,7 @@ jobs:
       - name: Run server tests
         run: |
           pip install pytest
-          export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv server/tests
       - name: Pre-commit checks
         run: |
...
@@ -455,6 +455,6 @@ class DeployedModel(BaseModel):
     # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
     # with model_ prefixes, since this disables guardrails for colliding fields:
     # https://github.com/pydantic/pydantic/issues/9177
     model_config = ConfigDict(protected_namespaces=())
     model_id: str
     sha: str
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 import os
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
@@ -7,7 +7,7 @@ if SYSTEM == "cuda":
     from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
-elif IPEX_AVAIL:
-    from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+elif SYSTEM == "ipex":
+    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
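
The `import_utils.py` side of this change is not part of the excerpt above. The sketch below is an assumption about how `SYSTEM` is presumably resolved to "ipex" (covering both XPU and CPU) once `IPEX_AVAIL` is gone; it is not the literal file contents.

```python
import torch


def is_ipex_available() -> bool:
    # Assumed helper: detection presumably just probes for the IPEX package.
    try:
        import intel_extension_for_pytorch  # noqa: F401
    except ImportError:
        return False
    return True


# Assumed resolution order: ROCm/CUDA first, then IPEX (XPU or CPU), then plain CPU.
if torch.version.hip is not None:
    SYSTEM = "rocm"
elif torch.version.cuda is not None and torch.cuda.is_available():
    SYSTEM = "cuda"
elif is_ipex_available():
    SYSTEM = "ipex"
else:
    SYSTEM = "cpu"
```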
@@ -3,7 +3,6 @@ from torch import nn
 from accelerate import init_empty_weights
 from text_generation_server.utils.import_utils import (
     SYSTEM,
-    IPEX_AVAIL,
 )
@@ -83,7 +82,7 @@ elif SYSTEM == "rocm":
             return super().forward(hidden_states), residual
-elif IPEX_AVAIL:
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
     class FastLayerNorm(nn.LayerNorm):
@@ -112,7 +111,7 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)
     def forward(self, hidden_states, residual=None):
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
            out = ipex.llm.functional.add_rms_norm(
                residual,
                hidden_states,
...
@@ -2,14 +2,14 @@ import os
 import torch
 from torch import nn
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 if SYSTEM == "cuda":
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
-elif IPEX_AVAIL:
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):
             # Inplace operation, updating query and key.
             ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif IPEX_AVAIL:
+        elif SYSTEM == "ipex":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
...
@@ -3,9 +3,9 @@ from torch.nn import functional as F
 from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
-if IPEX_AVAIL:
+if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
@@ -100,7 +100,7 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T
             torch.mm(input, self.linear.weight.T, out=local_out)
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -117,7 +117,7 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -217,7 +217,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
@@ -257,7 +257,7 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
...
@@ -20,9 +20,9 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
-if not IPEX_AVAIL:
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from text_generation_server.layers.attention import (
...
@@ -24,9 +24,9 @@ import torch.distributed
 import numpy as np
 from torch import nn
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
-if not IPEX_AVAIL:
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
...
@@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
@@ -768,12 +768,9 @@ class FlashCausalLM(Model):
         empty_cache()
         element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "xpu":
-            x = 1
-        else:
-            x = BLOCK_SIZE // element_size
+        x = BLOCK_SIZE // element_size
-        if IPEX_AVAIL and SYSTEM == "cpu":
+        if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
...
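
As a companion to the truncated hunk above, here is a hedged sketch of the kv-cache allocation difference that the commit message calls out as the biggest remaining special case. The tensor shapes are an assumption based on the vLLM-style paged layout used by the flash attention kernels; the actual allocation code in `flash_causal_lm.py` is elided in this excerpt.

```python
import torch

BLOCK_SIZE = 16  # assumed paged-attention block size


def allocate_kv_cache(num_blocks, num_layers, num_heads, head_size, dtype, device, system):
    # Shapes are illustrative assumptions, not the verbatim repository code.
    element_size = torch.tensor([], dtype=dtype).element_size()
    x = BLOCK_SIZE // element_size
    if system == "ipex" and device == torch.device("cpu"):
        # IPEX on CPU keeps a plain contiguous layout for both key and value blocks.
        key_shape = value_shape = (num_blocks, num_heads, BLOCK_SIZE, head_size)
    else:
        # CUDA/ROCm/XPU use the split paged layout expected by the kernels.
        key_shape = (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x)
        value_shape = (num_blocks, num_heads, head_size, BLOCK_SIZE)
    return [
        (
            torch.empty(key_shape, dtype=dtype, device=device),
            torch.empty(value_shape, dtype=dtype, device=device),
        )
        for _ in range(num_layers)
    ]
```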
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -34,12 +34,12 @@ class FlashGPT2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")
...
@@ -17,7 +17,7 @@ from text_generation_server.utils import (
 tracer = trace.get_tracer(__name__)
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 class FlashLlama(FlashCausalLM):
@@ -34,12 +34,12 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
...
@@ -16,7 +16,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -38,12 +38,12 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
...
@@ -14,7 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -33,12 +33,12 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
...
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -34,12 +34,12 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
...
@@ -18,7 +18,7 @@ from text_generation_server.utils import (
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 tracer = trace.get_tracer(__name__)
@@ -37,12 +37,12 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
...