Unverified Commit 58848cb4 authored by Dmitry Rogozhkin's avatar Dmitry Rogozhkin Committed by GitHub
Browse files

feat: enable pytorch xpu support for non-attention models (#2561)



XPU backend is available natively (without IPEX) in pytorch starting
from pytorch 2.4. This commit extends TGI to cover the case when user
has XPU support thru pytorch 2.4, but does not have IPEX installed.
Models which don't require attention can work. For attention required
models more work is needed to provide attention implementation.

Tested with the following models:
* teknium/OpenHermes-2.5-Mistral-7B
* bigscience/bloom-560m
* google/gemma-7b
* google/flan-t5-xxl
Signed-off-by: default avatarDmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
parent 7a82ddcb
......@@ -517,14 +517,13 @@ class CausalLM(Model):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
......@@ -593,8 +592,14 @@ class CausalLM(Model):
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
device_count = 0
if torch.cuda.is_available():
device = torch.device("cuda")
device_count = torch.cuda.device_count()
dtype = torch.float16 if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
device_count = torch.xpu.device_count()
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
......@@ -616,18 +621,17 @@ class CausalLM(Model):
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
if device_count > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if (
torch.cuda.is_available()
and torch.cuda.device_count() == 1
device_count == 1
and quantize != "bitsandbytes"
):
model = model.cuda()
model = model.to(device)
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
......
......@@ -558,14 +558,13 @@ class Seq2SeqLM(Model):
if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}")
dtype = default_dtype if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}")
dtype = default_dtype if dtype is None else dtype
else:
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
device = torch.device("cpu")
# Float16 doesn't exist on target.
dtype = torch.bfloat16 if dtype is None else dtype
else:
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
......@@ -630,8 +629,14 @@ class Seq2SeqLM(Model):
if speculator:
raise RuntimeError("Speculator decoding is not enabled for AutoModel")
device_count = 0
if torch.cuda.is_available():
device = torch.device("cuda")
device_count = torch.cuda.device_count()
dtype = torch.float16 if dtype is None else dtype
elif hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device("xpu")
device_count = torch.xpu.device_count()
dtype = torch.float16 if dtype is None else dtype
else:
if quantize:
......@@ -646,14 +651,14 @@ class Seq2SeqLM(Model):
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
if device_count > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if device_count == 1:
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(
model_id,
......
......@@ -66,6 +66,11 @@ elif is_ipex_available():
empty_cache = noop
synchronize = noop
get_free_memory = get_cpu_free_memory
elif hasattr(torch, "xpu") and torch.xpu.is_available():
SYSTEM = "xpu"
empty_cache = torch.xpu.empty_cache
synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory
else:
SYSTEM = "cpu"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment