"vscode:/vscode.git/clone" did not exist on "f82ade4777a5e91adac4438cf4a59172d026670b"
Unverified Commit 8f716817 authored by Huazhong Ji's avatar Huazhong Ji Committed by GitHub
Browse files

[HFLM]Add support for Ascend NPU (#1886)



* [HFLM]Add support for Ascend NPU
Co-authored-by: default avatarjiaqiw09 <jiaqiw960714@gmail.com>
Co-authored-by: default avatarzhabuye <2947436155@qq.com>

* bump accelerate dependency version to 0.26.0 for NPU compat.

---------
Co-authored-by: default avatarjiaqiw09 <jiaqiw960714@gmail.com>
Co-authored-by: default avatarzhabuye <2947436155@qq.com>
Co-authored-by: default avatarHailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent b4cd85d4
...@@ -153,12 +153,16 @@ class HFLM(TemplateLM): ...@@ -153,12 +153,16 @@ class HFLM(TemplateLM):
if accelerator.num_processes > 1: if accelerator.num_processes > 1:
self.accelerator = accelerator self.accelerator = accelerator
if "npu" in accelerator.device.type:
gpus = torch.npu.device_count()
if not (parallelize or accelerator.num_processes > 1): if not (parallelize or accelerator.num_processes > 1):
# use user-passed device # use user-passed device
device_list = set( device_list = set(
["cuda", "cpu"] ["cuda", "cpu"]
+ [f"cuda:{i}" for i in range(gpus)] + [f"cuda:{i}" for i in range(gpus)]
+ ["mps", "mps:0"] + ["mps", "mps:0"]
+ [f"npu:{i}" for i in range(gpus)]
) )
if device and device in device_list: if device and device in device_list:
self._device = torch.device(device) self._device = torch.device(device)
...@@ -323,6 +327,7 @@ class HFLM(TemplateLM): ...@@ -323,6 +327,7 @@ class HFLM(TemplateLM):
in [ in [
DistributedType.FSDP, DistributedType.FSDP,
DistributedType.MULTI_GPU, DistributedType.MULTI_GPU,
DistributedType.MULTI_NPU,
] ]
), "Unsupported distributed type provided. Only DDP and FSDP are supported." ), "Unsupported distributed type provided. Only DDP and FSDP are supported."
if accelerator.distributed_type == DistributedType.FSDP: if accelerator.distributed_type == DistributedType.FSDP:
......
...@@ -19,7 +19,7 @@ classifiers = [ ...@@ -19,7 +19,7 @@ classifiers = [
requires-python = ">=3.8" requires-python = ">=3.8"
license = { "text" = "MIT" } license = { "text" = "MIT" }
dependencies = [ dependencies = [
"accelerate>=0.21.0", "accelerate>=0.26.0",
"evaluate", "evaluate",
"datasets>=2.16.0", "datasets>=2.16.0",
"evaluate>=0.4.0", "evaluate>=0.4.0",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment