Unverified Commit 8823cc48 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
parents bce9499e 73f4dc57
......@@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
......@@ -47,7 +47,7 @@ def exam_gpt_fwd_bwd(
use_grad_checkpoint: bool = False,
master_weights: bool = True,
):
init_device = get_current_device()
init_device = get_accelerator().get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
......
......@@ -6,10 +6,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd
......@@ -53,7 +53,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
def exam_gemini_grad_acc(
placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
):
init_device = get_current_device()
init_device = get_accelerator().get_current_device()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
......
......@@ -7,11 +7,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
......@@ -47,7 +47,9 @@ def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
def single_chunk_init(model: torch.nn.Module, placement_config: dict):
model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
model = GeminiDDP(
model, chunk_init_device=get_accelerator().get_current_device(), pin_memory=True, **placement_config
)
return model
......@@ -63,7 +65,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
init_dev = get_current_device()
init_dev = get_accelerator().get_current_device()
model = model_builder().to(init_dev)
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
......
......@@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.device import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
......@@ -150,7 +150,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
model = GeminiDDP(
model,
chunk_init_device=get_current_device(),
chunk_init_device=get_accelerator().get_current_device(),
search_range_m=1,
pin_memory=True,
mixed_precision=mixed_precision,
......
......@@ -2,8 +2,8 @@ import pytest
import torch
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.kit.model_zoo import model_zoo
......@@ -34,7 +34,7 @@ def exam_chunk_manager():
sharded_ddp_model = model_builder()
chunk_manager = init_chunk_manager(
sharded_ddp_model,
get_current_device(),
get_accelerator().get_current_device(),
hidden_dim=128,
search_range_m=1,
min_chunk_size_m=0,
......
......@@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.testing import spawn
from colossalai.testing.random import seed_all
from colossalai.utils import conditional_context, get_current_device
from colossalai.utils import conditional_context
from colossalai.zero import LowLevelZeroOptimizer
......@@ -28,7 +29,7 @@ class MlpModel(nn.Module):
def exam_zero_1_2_grad_acc():
local_rank = torch.distributed.get_rank()
seed_all(2009)
device = get_current_device()
device = get_accelerator().get_current_device()
# create model
zero1_model = MlpModel().to(device)
zero2_model = copy.deepcopy(zero1_model)
......@@ -71,7 +72,7 @@ def exam_zero_1_2_grad_acc():
def exam_zero_1_grad_acc(sync):
local_rank = torch.distributed.get_rank()
seed_all(2008)
device = get_current_device()
device = get_accelerator().get_current_device()
# create models
zero_model = MlpModel()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment