Unverified Commit e5ce4c8e authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[npu] add npu support for gemini and zero (#5067)

* [npu] setup device utils (#5047)

* [npu] add npu device support

* [npu] support low level zero

* [test] update npu zero plugin test

* [hotfix] fix import

* [test] recover tests

* [npu] gemini support npu (#5052)

* [npu] refactor device utils

* [gemini] support npu

* [example] llama2+gemini support npu

* [kernel] add arm cpu adam kernel (#5065)

* [kernel] add arm cpu adam

* [optim] update adam optimizer

* [kernel] arm cpu adam remove bf16 support
parent 8d56c9c3
......@@ -3,7 +3,7 @@ import pytest
import colossalai
from colossalai.legacy.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.testing import spawn
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
def _run_colo_set_process_memory_fraction_and_colo_device_memory_capacity():
......
......@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
......
......@@ -9,7 +9,7 @@ import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd
......
......@@ -11,7 +11,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
......
......@@ -9,7 +9,7 @@ from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
......
......@@ -9,7 +9,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.testing import spawn
from colossalai.testing.random import seed_all
from colossalai.utils import conditional_context
from colossalai.utils import conditional_context, get_current_device
from colossalai.zero import LowLevelZeroOptimizer
......@@ -28,9 +28,9 @@ class MlpModel(nn.Module):
def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank()
seed_all(2009)
device = get_current_device()
# create model
zero1_model = MlpModel().cuda()
zero1_model = MlpModel().to(device)
zero2_model = copy.deepcopy(zero1_model)
# create optimizer
zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1)
......@@ -43,8 +43,8 @@ def exam_zero_1_2_grad_acc():
)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
input_data1 = torch.randn(32, 128, device=device)
input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
......@@ -71,14 +71,15 @@ def exam_zero_1_2_grad_acc():
def exam_zero_1_grad_acc(sync):
local_rank = torch.distributed.get_rank()
seed_all(2008)
device = get_current_device()
# create models
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
seed_all(2008)
zero_model = zero_model.cuda()
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
zero_model = zero_model.to(device)
torch_model = DDP(torch_model.to(device), bucket_cap_mb=0)
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
......@@ -94,8 +95,8 @@ def exam_zero_1_grad_acc(sync):
# create data
seed_all(2022 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
input_data2 = torch.randn(32, 128).cuda()
input_data1 = torch.randn(32, 128, device=device)
input_data2 = torch.randn(32, 128, device=device)
def fwd_bwd_func(no_sync, cur_data, check_flag):
# zero1 fwd and bwd
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment