Unverified Commit 58848cb4 authored by Dmitry Rogozhkin, committed by GitHub

feat: enable pytorch xpu support for non-attention models (#2561)



The XPU backend is available natively (without IPEX) in PyTorch starting
from version 2.4. This commit extends TGI to cover the case where a user
has XPU support through PyTorch 2.4 but does not have IPEX installed.
Models that don't require attention can work; for models that require
attention, more work is needed to provide an attention implementation.

Tested with the following models:
* teknium/OpenHermes-2.5-Mistral-7B
* bigscience/bloom-560m
* google/gemma-7b
* google/flan-t5-xxl
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
parent 7a82ddcb
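For context, the native-XPU probe this change relies on is simply the presence of the torch.xpu module plus a runtime availability check. A minimal standalone sketch (the helper name has_native_xpu is made up for illustration):

    import torch

    # PyTorch >= 2.4 ships the torch.xpu module natively; on older builds
    # the attribute is absent, so hasattr() guards the availability probe.
    def has_native_xpu() -> bool:
        return hasattr(torch, "xpu") and torch.xpu.is_available()

    if has_native_xpu():
        print(f"native XPU devices: {torch.xpu.device_count()}")
    else:
        print("no native XPU backend; CUDA/IPEX/CPU paths apply")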
@@ -517,11 +517,10 @@ class CausalLM(Model):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = default_dtype if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                # Float16 doesn't exist on target.
-                dtype = torch.bfloat16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device(f"xpu:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            device = torch.device("cpu")
+            # Float16 doesn't exist on target.
+            dtype = torch.bfloat16 if dtype is None else dtype
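The reordering matters because the XPU branch was previously nested under SYSTEM == "ipex" and so was unreachable without IPEX. A hedged sketch of the resulting per-rank selection; SYSTEM is a stand-in for TGI's detected-system string, and rank is hard-coded where it would normally come from the sharding launcher:

    import torch

    SYSTEM = "cpu"  # stand-in for TGI's detected-system string
    rank = 0        # normally supplied by the sharding launcher

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        # Reachable with plain PyTorch >= 2.4, no IPEX required.
        device = torch.device(f"xpu:{rank}")
    elif SYSTEM == "ipex":
        # IPEX installed but no XPU device: run on CPU in bfloat16,
        # since float16 is not supported there.
        device = torch.device("cpu")
    else:
        device = torch.device("cpu")

    x = torch.zeros(1, device=device)  # tensors land on the chosen device
    print(x.device)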
@@ -593,8 +592,14 @@ class CausalLM(Model):
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")

+        device_count = 0
         if torch.cuda.is_available():
             device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
             dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
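Initializing device_count to 0 lets every later multi-device check collapse to False on CPU-only hosts, so the CPU path needs no special casing. A standalone sketch of the pattern:

    import torch

    device_count = 0  # CPU default: all "> 1" / "== 1" checks stay False
    if torch.cuda.is_available():
        device = torch.device("cuda")
        device_count = torch.cuda.device_count()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
        device_count = torch.xpu.device_count()
    else:
        device = torch.device("cpu")

    print(device, device_count)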
@@ -616,18 +621,17 @@ class CausalLM(Model):
             torch_dtype=dtype,
             device_map=(
                 "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                if device_count > 1
                 else None
             ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if (
-            torch.cuda.is_available()
-            and torch.cuda.device_count() == 1
+            device_count == 1
             and quantize != "bitsandbytes"
         ):
-            model = model.cuda()
+            model = model.to(device)

         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
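model.cuda() only ever targets CUDA, so swapping it for model.to(device) is what lets the single-device path also serve XPU. A minimal sketch, using bloom-560m (one of the models tested in this commit) purely as an example:

    import torch
    from transformers import AutoModelForCausalLM

    # Pick whichever accelerator is present; fall back to CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device = torch.device("xpu")
    else:
        device = torch.device("cpu")

    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    model = model.to(device)  # backend-agnostic, unlike model.cuda()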
@@ -558,11 +558,10 @@ class Seq2SeqLM(Model):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = default_dtype if dtype is None else dtype
-        elif SYSTEM == "ipex":
-            if hasattr(torch, "xpu") and torch.xpu.is_available():
-                device = torch.device(f"xpu:{rank}")
-                dtype = default_dtype if dtype is None else dtype
-            else:
-                device = torch.device("cpu")
-                # Float16 doesn't exist on target.
-                dtype = torch.bfloat16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device(f"xpu:{rank}")
+            dtype = default_dtype if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            device = torch.device("cpu")
+            # Float16 doesn't exist on target.
+            dtype = torch.bfloat16 if dtype is None else dtype
@@ -630,8 +629,14 @@ class Seq2SeqLM(Model):
         if speculator:
             raise RuntimeError("Speculator decoding is not enabled for AutoModel")

+        device_count = 0
         if torch.cuda.is_available():
             device = torch.device("cuda")
+            device_count = torch.cuda.device_count()
+            dtype = torch.float16 if dtype is None else dtype
+        elif hasattr(torch, "xpu") and torch.xpu.is_available():
+            device = torch.device("xpu")
+            device_count = torch.xpu.device_count()
             dtype = torch.float16 if dtype is None else dtype
         else:
             if quantize:
@@ -646,14 +651,14 @@ class Seq2SeqLM(Model):
             torch_dtype=dtype,
             device_map=(
                 "auto"
-                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                if device_count > 1
                 else None
             ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
+        if device_count == 1:
+            model = model.to(device)

         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
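The same device_count also gates device_map="auto": with several devices, transformers (via accelerate) shards the model across them and no explicit .to() is needed; with exactly one, the model is loaded normally and then moved. A hedged sketch, with flan-t5-small as a small stand-in for the flan-t5-xxl tested in this commit (the multi-device path additionally requires the accelerate package):

    import torch
    from transformers import AutoModelForSeq2SeqLM

    if torch.cuda.is_available():
        device, device_count = torch.device("cuda"), torch.cuda.device_count()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        device, device_count = torch.device("xpu"), torch.xpu.device_count()
    else:
        device, device_count = torch.device("cpu"), 0

    model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-small",
        torch_dtype=torch.float16,
        device_map="auto" if device_count > 1 else None,  # shard if multi-device
    )
    if device_count == 1:
        model = model.to(device)  # single accelerator: move the whole model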
@@ -66,6 +66,11 @@ elif is_ipex_available():
         empty_cache = noop
         synchronize = noop
         get_free_memory = get_cpu_free_memory
+elif hasattr(torch, "xpu") and torch.xpu.is_available():
+    SYSTEM = "xpu"
+    empty_cache = torch.xpu.empty_cache
+    synchronize = torch.xpu.synchronize
+    get_free_memory = get_xpu_free_memory
 else:
     SYSTEM = "cpu"
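The hunk above extends the dispatch pattern in import_utils.py: resolve a backend label plus a set of helper callables once at import time, so call sites never branch on the backend. A reduced sketch of the idea (the get_free_memory helpers are omitted; only real torch APIs are used):

    import torch

    def noop(*args, **kwargs):
        # Placeholder for backends with no cache/synchronization concept.
        pass

    if torch.cuda.is_available():
        SYSTEM = "cuda"
        empty_cache = torch.cuda.empty_cache
        synchronize = torch.cuda.synchronize
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        SYSTEM = "xpu"  # native PyTorch XPU, no IPEX installed
        empty_cache = torch.xpu.empty_cache
        synchronize = torch.xpu.synchronize
    else:
        SYSTEM = "cpu"
        empty_cache = noop
        synchronize = noop

    # Call sites stay backend-agnostic:
    synchronize()
    empty_cache()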