Unverified Commit 5da4cfab authored by Wang, Yi, committed by GitHub

refine get xpu free memory/enable Qwen2/gemma2/gemma/phi in intel platform (#2132)



* refine get xpu free memory
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable qwen2 in xpu
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* enable gemma/gemma2/phi in intel platform
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
parent 9d0ca503
@@ -14,6 +14,7 @@ def attention(
     max_s,
     softmax_scale,
     window_size_left=-1,
+    causal=True,
 ):
     # We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
     return ipex.llm.functional.varlen_attention(
@@ -28,7 +29,7 @@ def attention(
         0.0,
         softmax_scale,
         False,
-        True,
+        causal,
         False,
         None,
     )
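For orientation, a rough, hypothetical call-site sketch for the patched wrapper. Only the `window_size_left` and `causal` keywords are taken from the hunk above; the leading tensor arguments are not visible in this diff, so their names and order here are assumptions, not the wrapper's actual signature.

```python
def run_attention(query, key, value, out, cu_seqlens, max_s, softmax_scale):
    # Hypothetical helper: the argument names above are placeholders.
    return attention(
        query,
        key,
        value,
        out,
        cu_seqlens,
        max_s,
        softmax_scale,
        window_size_left=-1,  # sliding-window attention not used
        causal=True,          # new keyword; pass False for non-causal attention
    )
```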
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -32,6 +33,13 @@ class FlashGemma(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma is only available on GPU")
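The same `elif SYSTEM == "ipex"` branch is added to FlashGemma2, FlashPhi and FlashQwen2 below; only the CUDA default dtype differs per model (bfloat16 for Gemma/Gemma2, float16 for Phi/Qwen2). As a standalone illustration, a minimal sketch of the selection logic; the helper name `select_device_and_dtype` is invented for this example, and the branches mirror the diff:

```python
import torch


def select_device_and_dtype(rank: int, dtype, system: str):
    """Sketch of the device/dtype choice the patched model classes now make.

    `system` stands in for SYSTEM from text_generation_server.utils.import_utils,
    which is "ipex" when the server runs on Intel Extension for PyTorch.
    """
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{rank}")
        dtype = torch.bfloat16 if dtype is None else dtype  # float16 for Phi/Qwen2
    elif system == "ipex":
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            # Intel GPU (XPU): default to float16
            device = torch.device(f"xpu:{rank}")
            dtype = torch.float16 if dtype is None else dtype
        else:
            # IPEX on CPU: default to bfloat16
            device = torch.device("cpu")
            dtype = torch.bfloat16 if dtype is None else dtype
    else:
        raise NotImplementedError("Model is only available on GPU")
    return device, dtype
```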
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -32,6 +33,13 @@ class FlashGemma2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.bfloat16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGemma2 is only available on GPU")
@@ -14,6 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -32,6 +33,13 @@ class FlashPhi(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashPhi is only available on GPU")
@@ -19,6 +19,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
+from text_generation_server.utils.import_utils import SYSTEM

 tracer = trace.get_tracer(__name__)
@@ -37,6 +38,13 @@ class FlashQwen2(BaseFlashMistral):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+                dtype = torch.float16 if dtype is None else dtype
+            else:
+                device = torch.device("cpu")
+                dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashQwen2 is only available on GPU")
 import torch
 from loguru import logger
 import subprocess
+import os


 def is_ipex_available():
@@ -21,10 +22,13 @@ def get_cuda_free_memory(device, memory_fraction):
 def get_xpu_free_memory(device, memory_fraction):
     total_memory = torch.xpu.get_device_properties(device).total_memory
     device_id = device.index
-    query = f"xpu-smi dump -d {device_id} -m 18 -n 1"
-    output = subprocess.check_output(query.split()).decode("utf-8").split("\n")
-    used_memory = float(output[1].split(",")[-1]) * 1024 * 1024
-    free_memory = int(total_memory * 0.95 - used_memory)
+    memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
+    free_memory = max(
+        0,
+        int(
+            total_memory * 0.9 * memory_fraction - torch.xpu.memory_reserved(device_id)
+        ),
+    )
     return free_memory
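Put together, the reworked helper no longer shells out to `xpu-smi`; it budgets 90% of total device memory, scaled by the `XPU_MEMORY_FRACTION` environment variable, and subtracts what PyTorch's XPU allocator has already reserved. A minimal standalone sketch of that computation (the function name and the example figures in the comment are illustrative only):

```python
import os

import torch


def xpu_free_memory_estimate(device: torch.device) -> int:
    """Estimate free XPU memory the way the patched helper does."""
    total_memory = torch.xpu.get_device_properties(device).total_memory
    memory_fraction = float(os.getenv("XPU_MEMORY_FRACTION", "1.0"))
    reserved = torch.xpu.memory_reserved(device.index)
    # Example: a 16 GiB card with fraction 1.0 and 2 GiB already reserved
    # yields int(16 GiB * 0.9 - 2 GiB), roughly 12.4 GiB reported as free.
    return max(0, int(total_memory * 0.9 * memory_fraction - reserved))
```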