Unverified Commit 368275f3 authored by kaixuanliu, committed by GitHub

add xpu support HFLM (#3211)


Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
parent 7f698a5a
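
With this change an Intel XPU can be selected the same way as a CUDA or NPU device. A minimal usage sketch, assuming lm-evaluation-harness is installed on a machine with a PyTorch XPU build (the model and task names below are placeholders, not part of this commit):

    lm_eval --model hf --model_args pretrained=gpt2 --tasks lambada_openai --device xpu:0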
@@ -124,14 +124,22 @@ class HFLM(TemplateLM):
         assert isinstance(pretrained, str)
         assert isinstance(batch_size, (int, str))
 
-        gpus = torch.cuda.device_count()
         accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
         accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
         if accelerator.num_processes > 1:
             self.accelerator = accelerator
 
-        if "npu" in accelerator.device.type:
+        # Detect device count based on accelerator device type
+        device_type = accelerator.device.type
+        if "cuda" in device_type:
+            gpus = torch.cuda.device_count()
+        elif "npu" in device_type:
             gpus = torch.npu.device_count()
+        elif "xpu" in device_type:
+            gpus = torch.xpu.device_count()
+        else:
+            # Fallback to CUDA count for compatibility
+            gpus = torch.cuda.device_count()
 
         # using one process with no model parallelism
         if not (parallelize or accelerator.num_processes > 1):
@@ -141,6 +149,7 @@ class HFLM(TemplateLM):
                 + [f"cuda:{i}" for i in range(gpus)]
                 + ["mps", "mps:0"]
                 + [f"npu:{i}" for i in range(gpus)]
+                + [f"xpu:{i}" for i in range(gpus)]
             )
             if device and device in device_list:
                 self._device = torch.device(device)
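
The device-count logic added above can be read on its own. A minimal sketch of the same branching, assuming Accelerate is installed and that torch.npu / torch.xpu are only present when the corresponding backend support is installed (the helper name count_accelerator_devices is illustrative, not part of the harness):

import torch
from accelerate import Accelerator

def count_accelerator_devices(accelerator: Accelerator) -> int:
    """Return the device count for whichever backend Accelerate selected."""
    device_type = accelerator.device.type
    if "cuda" in device_type:
        return torch.cuda.device_count()
    elif "npu" in device_type:
        # torch.npu exists only when the Ascend torch_npu package is installed
        return torch.npu.device_count()
    elif "xpu" in device_type:
        # torch.xpu is available in PyTorch builds with Intel XPU support
        return torch.xpu.device_count()
    # Fall back to the CUDA count for compatibility with the previous behavior
    return torch.cuda.device_count()

# Example (on an Intel GPU host this would report the XPU count):
# accelerator = Accelerator()
# print(accelerator.device.type, count_accelerator_devices(accelerator))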