Unverified Commit e5ce4c8e authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
parent 8d56c9c3
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import colossalai import colossalai
from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.testing import spawn from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity(): def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
......
...@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp ...@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
......
...@@ -9,7 +9,7 @@ import colossalai ...@@ -9,7 +9,7 @@ import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd from tests.kit.model_zoo import model_zoo, run_fwd
......
...@@ -11,7 +11,7 @@ from colossalai.legacy.amp import convert_to_apex_amp ...@@ -11,7 +11,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
......
...@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp ...@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
......
...@@ -9,7 +9,7 @@ from torch.testing import assert_close ...@@ -9,7 +9,7 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.testing import spawn from colossalai.testing import spawn
from colossalai.testing.random import seed_all from colossalai.testing.random import seed_all
from colossalai.utils import conditional_context from colossalai.utils import conditional_context, get_current_device
from colossalai.zero import LowLevelZeroOptimizer from colossalai.zero import LowLevelZeroOptimizer
...@@ -28,9 +28,9 @@ class MlpModel(nn.Module): ...@@ -28,9 +28,9 @@ class MlpModel(nn.Module):
def exam_zero_1_2_grad_acc(): def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank() local_rank = torch.distributed.get_rank()
seed_all(2009) seed_all(2009)
device = get_current_device()
# create model # create model
zero1_model = MlpModel().cuda() zero1_model = MlpModel().to(device)
zero2_model = copy.deepcopy(zero1_model) zero2_model = copy.deepcopy(zero1_model)
# create optimizer # create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
...@@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc(): ...@@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc():
) )
# create data # create data
seed_all(2021 + local_rank) seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda() input_data1 = torch.randn(32, 128, device=device)
input_data2 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(number, cur_data, check_flag): def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward # zero-dp forward
...@@ -71,14 +71,15 @@ def exam_zero_1_2_grad_acc(): ...@@ -71,14 +71,15 @@ def exam_zero_1_2_grad_acc():
def exam_zero_1_grad_acc(sync): def exam_zero_1_grad_acc(sync):
local_rank = torch.distributed.get_rank() local_rank = torch.distributed.get_rank()
seed_all(2008) seed_all(2008)
device = get_current_device()
# create models # create models
zero_model = MlpModel() zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model) torch_model = copy.deepcopy(zero_model)
seed_all(2008) seed_all(2008)
zero_model = zero_model.cuda() zero_model = zero_model.to(device)
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) torch_model = DDP(torch_model.to(device), bucket_cap_mb=0)
# create optimizer # create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
...@@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync): ...@@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync):
# create data # create data
seed_all(2022 + local_rank) seed_all(2022 + local_rank)
input_data1 = torch.randn(32, 128).cuda() input_data1 = torch.randn(32, 128, device=device)
input_data2 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(no_sync, cur_data, check_flag): def fwd_bwd_func(no_sync, cur_data, check_flag):
# zero1 fwd and bwd # zero1 fwd and bwd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment