Unverified commit 8823cc48, authored by Frank Lee, committed by GitHub

Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
parents bce9499e 73f4dc57
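This merge switches the tests from the CUDA-only helper `colossalai.utils.device.get_current_device` to the backend-agnostic accelerator API, so the same test code can run on NPUs as well as GPUs. Below is a minimal sketch of the migration pattern repeated in every hunk that follows; the comments describing backend resolution are inferred from the API names and the `feature/npu` branch, not stated in this diff.

```python
# Before (CUDA-centric helper, removed by this PR):
#     from colossalai.utils.device import get_current_device
#     init_device = get_current_device()

# After (accelerator API, added by this PR):
from colossalai.accelerator import get_accelerator

# get_accelerator() resolves the accelerator object for the active backend
# (e.g. CUDA or NPU); get_current_device() then returns its current device,
# usable anywhere a torch device is expected.
init_device = get_accelerator().get_current_device()
```

Because every call site now goes through `get_accelerator()`, supporting a new backend should only require a new accelerator implementation rather than edits to each test.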
@@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd
@@ -47,7 +47,7 @@ def exam_gpt_fwd_bwd(
     use_grad_checkpoint: bool = False,
     master_weights: bool = True,
 ):
-    init_device = get_current_device()
+    init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
...
@@ -6,10 +6,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd
@@ -53,7 +53,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 def exam_gemini_grad_acc(
     placement_config, keep_gathered: bool, model_name: str, master_weights: bool, use_grad_checkpoint: bool
 ):
-    init_device = get_current_device()
+    init_device = get_accelerator().get_current_device()
     model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
...
@@ -7,11 +7,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
@@ -47,7 +47,9 @@ def multi_chunk_init(model: torch.nn.Module, placement_config: dict):
 def single_chunk_init(model: torch.nn.Module, placement_config: dict):
-    model = GeminiDDP(model, chunk_init_device=get_current_device(), pin_memory=True, **placement_config)
+    model = GeminiDDP(
+        model, chunk_init_device=get_accelerator().get_current_device(), pin_memory=True, **placement_config
+    )
     return model
@@ -63,7 +65,7 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
     torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
     torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
     torch_model = DDP(torch_model, device_ids=[dist.get_rank()])
-    init_dev = get_current_device()
+    init_dev = get_accelerator().get_current_device()
     model = model_builder().to(init_dev)
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
...
@@ -5,11 +5,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
-from colossalai.utils.device import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from tests.kit.model_zoo import model_zoo, run_fwd_bwd
@@ -150,7 +150,7 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
     model = GeminiDDP(
         model,
-        chunk_init_device=get_current_device(),
+        chunk_init_device=get_accelerator().get_current_device(),
         search_range_m=1,
         pin_memory=True,
         mixed_precision=mixed_precision,
...
@@ -2,8 +2,8 @@ import pytest
 import torch
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
 from tests.kit.model_zoo import model_zoo
@@ -34,7 +34,7 @@ def exam_chunk_manager():
     sharded_ddp_model = model_builder()
     chunk_manager = init_chunk_manager(
         sharded_ddp_model,
-        get_current_device(),
+        get_accelerator().get_current_device(),
         hidden_dim=128,
         search_range_m=1,
         min_chunk_size_m=0,
...
@@ -7,9 +7,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
-from colossalai.utils import conditional_context, get_current_device
+from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
@@ -28,7 +29,7 @@ class MlpModel(nn.Module):
 def exam_zero_1_2_grad_acc():
     local_rank = torch.distributed.get_rank()
     seed_all(2009)
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     # create model
     zero1_model = MlpModel().to(device)
     zero2_model = copy.deepcopy(zero1_model)
@@ -71,7 +72,7 @@ def exam_zero_1_2_grad_acc():
 def exam_zero_1_grad_acc(sync):
     local_rank = torch.distributed.get_rank()
     seed_all(2008)
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     # create models
     zero_model = MlpModel()
...