Unverified Commit 8f716817 authored by Huazhong Ji's avatar Huazhong Ji Committed by GitHub
Browse files

[HFLM]Add support for Ascend NPU (#1886)



* [HFLM]Add support for Ascend NPU
Co-authored-by: default avatarjiaqiw09 <jiaqiw960714@gmail.com>
Co-authored-by: default avatarzhabuye <2947436155@qq.com>

* bump accelerate dependency version to 0.26.0 for NPU compat.

---------
Co-authored-by: default avatarjiaqiw09 <jiaqiw960714@gmail.com>
Co-authored-by: default avatarzhabuye <2947436155@qq.com>
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent b4cd85d4
......@@ -153,12 +153,16 @@ class HFLM(TemplateLM):
if accelerator.num_processes > 1:
self.accelerator = accelerator
if "npu" in accelerator.device.type:
gpus = torch.npu.device_count()
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(gpus)]
+ ["mps", "mps:0"]
+ [f"npu:{i}" for i in range(gpus)]
)
if device and device in device_list:
self._device = torch.device(device)
......@@ -323,6 +327,7 @@ class HFLM(TemplateLM):
in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
]
), "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP:
......
......@@ -19,7 +19,7 @@ classifiers = [
requires-python = ">=3.8"
license = { "text" = "MIT" }
dependencies = [
"accelerate>=0.21.0",
"accelerate>=0.26.0",
"evaluate",
"datasets>=2.16.0",
"evaluate>=0.4.0",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment