Unverified commit d202cc28 authored by Hongxin Liu, committed by GitHub

[npu] change device to accelerator api (#5239)



* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
parent dd2c28a3
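Nearly every hunk below applies the same substitution: the CUDA-flavoured device helpers from `colossalai.utils` are replaced by the device-agnostic accelerator API. A minimal before/after sketch, using only names that appear in this diff (the `torch.empty` call is just an illustrative consumer):

```python
import torch

# old: device helper tied to colossalai.utils (effectively CUDA-centric)
# from colossalai.utils import get_current_device
# buf = torch.empty(16, device=get_current_device())

# new: ask the accelerator abstraction (CUDA, NPU, ...) for the current device
from colossalai.accelerator import get_accelerator

buf = torch.empty(16, device=get_accelerator().get_current_device())
```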
@@ -12,6 +12,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
+from colossalai.accelerator import get_accelerator
 from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
 from colossalai.checkpoint_io.utils import (
     get_optimizer_base_filenames,
@@ -24,7 +25,6 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
-from colossalai.utils import get_current_device
 from colossalai.zero import LowLevelZeroOptimizer
 from .dp_plugin_base import DPPluginBase
@@ -52,7 +52,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             self.dtype = torch.bfloat16
         if self.dtype is not None:
             module = module.to(self.dtype)
-        module = module.to(get_current_device())
+        module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
         if self.dtype is not None:
...
@@ -6,12 +6,12 @@ import warnings
 from pathlib import Path
 from typing import Dict, Union
-import torch
 import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
 from colossalai.context import Config
 from colossalai.logging import get_dist_logger
-from colossalai.utils import IS_NPU_AVAILABLE, set_device, set_seed
+from colossalai.utils import set_seed
 def launch(
@@ -47,17 +47,18 @@ def launch(
         if rank == 0:
             warnings.warn("`config` is deprecated and will be removed soon.")
-    if IS_NPU_AVAILABLE and backend == "nccl":
-        backend = "hccl"
+    cur_accelerator = get_accelerator()
+    backend = cur_accelerator.communication_backend
     # init default process group
     init_method = f"tcp://[{host}]:{port}"
     dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
     # set cuda device
-    if torch.cuda.is_available() or IS_NPU_AVAILABLE:
-        # if local rank is not given, calculate automatically
-        set_device(local_rank)
+    # if local rank is not given, calculate automatically
+    if cur_accelerator.support_set_device:
+        cur_accelerator.set_device(local_rank)
     set_seed(seed)
...
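`launch()` no longer special-cases NPU by rewriting the `nccl` backend to `hccl`; the accelerator now reports its own communication backend and handles device binding. A condensed, hypothetical sketch of that flow, assuming only the attributes used in the hunk above (`communication_backend`, `support_set_device`, `set_device`):

```python
import torch.distributed as dist

from colossalai.accelerator import get_accelerator


def init_distributed(rank: int, world_size: int, host: str, port: int, local_rank: int) -> None:
    # hypothetical condensation of launch(); argument validation and seeding are omitted
    cur_accelerator = get_accelerator()

    # e.g. "nccl" on CUDA or "hccl" on Ascend NPU, matching the removed special case
    backend = cur_accelerator.communication_backend

    init_method = f"tcp://[{host}]:{port}"
    dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)

    # bind this process to its local device only when the accelerator supports it
    if cur_accelerator.support_set_device:
        cur_accelerator.set_device(local_rank)
```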
@@ -6,7 +6,7 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 class Unpad(torch.autograd.Function):
@@ -70,7 +70,9 @@ class SeqLenInfo:
     cu_seqlens: torch.Tensor = None
     @staticmethod
-    def materialize(attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_current_device()):
+    def materialize(
+        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
+    ):
         if attn_mask is not None:
             indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
             seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
...
 import torch
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.nn.layer.colossalai_layer import Embedding, Linear
-from colossalai.utils import get_current_device
 from .bias_dropout_add import bias_dropout_add_fused_train
 from .bias_gelu import bias_gelu_impl
@@ -46,11 +46,13 @@ def warmup_jit_fusion(
 ):
     """Compile JIT functions before the main training steps"""
-    embed = Embedding(vocab_size, hidden_size).to(get_current_device())
-    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
-    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_current_device())
-    x = torch.randint(vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_current_device())
+    embed = Embedding(vocab_size, hidden_size).to(get_accelerator().get_current_device())
+    linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_accelerator().get_current_device())
+    linear_2 = Linear(hidden_size * 4, hidden_size, skip_bias_add=True).to(get_accelerator().get_current_device())
+    x = torch.randint(
+        vocab_size, (batch_size, seq_length), dtype=torch.long, device=get_accelerator().get_current_device()
+    )
     x = embed(x)
     y, y_bias = linear_1(x)
     z, z_bias = linear_2(y)
@@ -58,8 +60,8 @@ def warmup_jit_fusion(
     # prop and recomputation
     for bias_grad, input_grad in zip([True, True], [False, True]):
         for _ in range(10):
-            bias = torch.rand_like(y_bias, dtype=dtype, device=get_current_device())
-            input_ = torch.rand_like(y, dtype=dtype, device=get_current_device())
+            bias = torch.rand_like(y_bias, dtype=dtype, device=get_accelerator().get_current_device())
+            input_ = torch.rand_like(y, dtype=dtype, device=get_accelerator().get_current_device())
             bias.requires_grad, input_.requires_grad = bias_grad, input_grad
             bias_gelu_impl(input_, bias)
@@ -69,9 +71,9 @@ def warmup_jit_fusion(
     # prop and recomputation
     for input_grad, bias_grad, residual_grad in zip([False, True], [True, True], [True, True]):
         for _ in range(10):
-            input_ = torch.rand_like(z, dtype=dtype, device=get_current_device())
-            residual = torch.rand_like(x, dtype=dtype, device=get_current_device())
-            bias = torch.rand_like(z_bias, dtype=dtype, device=get_current_device())
+            input_ = torch.rand_like(z, dtype=dtype, device=get_accelerator().get_current_device())
+            residual = torch.rand_like(x, dtype=dtype, device=get_accelerator().get_current_device())
+            bias = torch.rand_like(z_bias, dtype=dtype, device=get_accelerator().get_current_device())
             input_.requires_grad = input_grad
             bias.requires_grad = bias_grad
             residual.requires_grad = residual_grad
...
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from colossalai.utils.device import autocast
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss
 from torch.optim import Optimizer
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.legacy.utils import clip_grad_norm_fp32
 from ._grad_scaler import GradScaler
+autocast = get_accelerator().autocast
 class TorchAMPOptimizer(OptimizerWrapper):
     """A wrapper class which integrate Pytorch AMP with an optimizer
...
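The AMP wrapper no longer imports `autocast` from `colossalai.utils.device`; it binds `autocast = get_accelerator().autocast` at module level so the rest of the file stays unchanged. A small hedged sketch of how such an accelerator-provided autocast would be used, assuming it behaves like `torch.cuda.amp.autocast` (a context manager for mixed-precision regions):

```python
from colossalai.accelerator import get_accelerator

# module-level alias, exactly as introduced in the hunk above
autocast = get_accelerator().autocast


def forward_with_amp(model, batch):
    # run the forward pass in reduced precision on whichever device the accelerator manages
    with autocast():
        return model(batch)
```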
@@ -8,9 +8,9 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
@@ -43,12 +43,16 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
 def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
     if isinstance(recv_shapes, torch.Size):
         recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
-        buffer_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        buffer_recv = torch.empty(
+            recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
+        )
         return buffer_recv, recv_split
     buffer_recv = []
     for recv_shape in recv_shapes:
         recv_chunk_shape, recv_split = _get_tensor_shape(recv_shape, scatter_gather_tensors)
-        tensor_recv = torch.empty(recv_chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
+        tensor_recv = torch.empty(
+            recv_chunk_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=dtype
+        )
         buffer_recv.append(tensor_recv)
     return buffer_recv, recv_split
...
@@ -3,9 +3,9 @@
 import torch
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device, synchronize
 def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) -> torch.Tensor:
@@ -29,7 +29,7 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
     current_rank = gpc.get_global_rank()
     tensor_recv_prev = torch.empty(
-        buffer_shape, requires_grad=True, device=get_current_device(), dtype=tensor_send_next.dtype
+        buffer_shape, requires_grad=True, device=get_accelerator().get_current_device(), dtype=tensor_send_next.dtype
     )
     # send to next rank
@@ -52,6 +52,6 @@ def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode) ->
         req.wait()
     # To protect against race condition when using batch_isend_irecv().
-    synchronize()
+    get_accelerator().synchronize()
     return tensor_recv_prev
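`ring_forward()` now synchronizes through the accelerator instead of the removed CUDA-only `synchronize` helper. A minimal sketch of the same guard around batched point-to-point communication; `dist.batch_isend_irecv` and the op list are standard PyTorch, everything else comes from this hunk:

```python
import torch.distributed as dist

from colossalai.accelerator import get_accelerator


def drain_p2p_ops(ops):
    # hypothetical helper: launch batched isend/irecv ops and wait for completion
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    # protect against the batch_isend_irecv race condition on whatever device is active (CUDA, NPU, ...)
    get_accelerator().synchronize()
```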
@@ -3,9 +3,9 @@ from typing import List, Tuple, Union
 import torch
 import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 TensorShape = Union[torch.Size, List[int], Tuple[int]]
@@ -35,7 +35,7 @@ def send_obj_meta(obj, need_meta=True, next_rank=None) -> bool:
         if next_rank is None:
             next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
-        tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
+        tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
         if isinstance(obj, torch.Tensor):
             send_obj_nums = torch.tensor(1, **tensor_kwargs)
             dist.send(send_obj_nums, next_rank)
@@ -74,7 +74,7 @@ def recv_obj_meta(obj_shape, prev_rank=None) -> torch.Size:
         if prev_rank is None:
             prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
-        tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
+        tensor_kwargs = {"dtype": torch.long, "device": get_accelerator().get_current_device()}
         recv_obj_nums = torch.empty((), **tensor_kwargs)
         dist.recv(recv_obj_nums, prev_rank)
         if recv_obj_nums.item() == 1:
...
@@ -6,8 +6,8 @@ from typing import Callable, Iterable
 import torch
+from colossalai.accelerator import get_accelerator
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 class BaseSchedule(ABC):
@@ -29,12 +29,12 @@ class BaseSchedule(ABC):
     def _move_tensor(element):
         if torch.is_tensor(element):
             if not element.is_cuda:
-                return element.to(get_current_device()).detach()
+                return element.to(get_accelerator().get_current_device()).detach()
         return element
     def _move_to_device(self, data):
         if isinstance(data, torch.Tensor):
-            data = data.to(get_current_device())
+            data = data.to(get_accelerator().get_current_device())
         elif isinstance(data, (list, tuple)):
             data_to_return = []
             for element in data:
...
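`BaseSchedule` moves batch data onto the current device via the accelerator as well. A compact sketch of that movement pattern for nested containers; the dict branch is an illustrative addition rather than a claim about what the schedule itself supports:

```python
import torch

from colossalai.accelerator import get_accelerator


def move_to_current_device(data):
    # recursively move tensors in a (possibly nested) batch onto the accelerator's device
    device = get_accelerator().get_current_device()
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_current_device(x) for x in data)
    if isinstance(data, dict):
        return {key: move_to_current_device(value) for key, value in data.items()}
    return data
```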
@@ -7,12 +7,12 @@ from typing import Callable, List, Tuple, Union
 import torch.cuda
 import colossalai.legacy.communication as comm
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.amp.naive_amp import NaiveAMPModel
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.logging import get_dist_logger
-from colossalai.utils.device import get_current_device
 from ._base_schedule import BaseSchedule
@@ -352,7 +352,7 @@ class PipelineSchedule(BaseSchedule):
         output_objs = []
         return_tensors = []
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
         # Used for tensor meta information communication
@@ -584,7 +584,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         if not forward_only:
             output_obj_grads = [[] for _ in range(len(model))]
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
...
@@ -6,10 +6,10 @@ from typing import Iterable, Tuple
 import torch.cuda
 import colossalai.legacy.communication.p2p_v2 as comm
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.engine import Engine
-from colossalai.utils.device import get_current_device
 from ._pipeline_schedule import PipelineSchedule
@@ -99,7 +99,7 @@ class PipelineScheduleV2(PipelineSchedule):
         output_objs = []
         return_tensors = []
         if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
...
@@ -15,6 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
+from colossalai.accelerator import get_accelerator
 from colossalai.context import Config, ConfigException
 from colossalai.interface import OptimizerWrapper
 from colossalai.legacy.amp import AMP_TYPE, convert_to_amp
@@ -34,7 +35,6 @@ from colossalai.legacy.utils import is_using_ddp, is_using_pp, is_using_sequence
 from colossalai.legacy.zero import ShardedOptimizerV2, convert_to_zero_v2
 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 def get_default_parser():
@@ -309,9 +309,9 @@ def initialize(
     else:
         if isinstance(model, nn.Module):
             # first sync model across dp ranks
-            model.to(get_current_device())
+            model.to(get_accelerator().get_current_device())
         elif isinstance(model, Callable):
-            model = model().to(get_current_device())
+            model = model().to(get_accelerator().get_current_device())
     # optimizer maybe a optimizer_cls
     if isinstance(optimizer, Callable):
...
@@ -3,8 +3,8 @@ from typing import Callable
 from torch import dtype, nn
+from colossalai.accelerator import get_accelerator
 from colossalai.nn import init
-from colossalai.utils import get_current_device
 from ..parallel_1d import Embedding1D, PatchEmbedding1D, VocabParallelEmbedding1D
 from ..parallel_2d import Embedding2D, PatchEmbedding2D, VocabParallelEmbedding2D
@@ -83,7 +83,7 @@ class Embedding(ColossalaiModule):
             embed = (
                 nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args, **kwargs)
                 .to(dtype)
-                .to(get_current_device())
+                .to(get_accelerator().get_current_device())
             )
             weight_initializer(embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
         elif num_embeddings <= vocab_parallel_limit:
...
 from torch import nn
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 from ..parallel_1d import LayerNorm1D
 from ..parallel_2d import LayerNorm2D
@@ -36,7 +36,7 @@ class LayerNorm(ColossalaiModule):
     def __init__(self, normalized_shape: int, eps=1e-05, bias=True, dtype=None) -> None:
         tensor_parallel = get_tensor_parallel_mode()
         if tensor_parallel is None:
-            norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
+            norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_accelerator().get_current_device())
         else:
             norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
         super().__init__(norm)
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn.parameter import Parameter
+from colossalai.accelerator import get_accelerator
 from colossalai.kernel import LayerNorm
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
@@ -22,7 +23,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 from ..base_layer import ParallelLayer
 from ..colossalai_layer._utils import ColossalaiModule
@@ -221,7 +221,7 @@ class Classifier1D(ParallelLayer):
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
@@ -357,7 +357,7 @@ class VocabParallelClassifier1D(ParallelLayer):
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
@@ -499,7 +499,7 @@ class Linear1D_Col(ParallelLayer):
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
         if bias:
@@ -638,7 +638,7 @@ class Linear1D_Row(ParallelLayer):
         # Parameters.
         # Initialize weight.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
         if self.stream_chunk_num > 1:
@@ -802,7 +802,9 @@ class Embedding1D(ParallelLayer):
         self.embed_kwargs = kwargs
         self.weight = Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
         self.reset_parameters(weight_initializer)
@@ -912,7 +914,11 @@ class VocabParallelEmbedding1D(ParallelLayer):
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (self.num_embeddings_per_partition, self.embed_dim),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
+            )
         )
         self.reset_parameters(weight_initializer)
...
@@ -5,10 +5,10 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 def matmul_2d(
@@ -250,7 +250,7 @@ class Matmul_AB_2D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
-        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -399,7 +399,7 @@ class Matmul_ABT_2D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[0])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -556,7 +556,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[-1], B.shape[-1])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.core import global_context as gpc
@@ -18,7 +19,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@@ -82,7 +82,7 @@ class Linear2D(ParallelLayer):
         self.hidden_size_per_partition = divide(self.out_features, self.summa_dim)
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(
             torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
         )
@@ -259,7 +259,7 @@ class LayerNorm2D(ParallelLayer):
         self.partitioned_partition = divide(normalized_shape, self.summa_dim**2)
         # create parameters
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
         if bias:
@@ -438,18 +438,24 @@ class PatchEmbedding2D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.embed_size_per_partition, in_chans, *self.patch_size),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
-        self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
+        self.bias = Parameter(
+            torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+        )
         self.cls_token = Parameter(
-            torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
+            torch.zeros(
+                (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
         self.pos_embed = Parameter(
             torch.zeros(
-                (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
+                (1, self.num_patches + 1, self.embed_size_per_partition),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
             )
         )
@@ -619,7 +625,9 @@ class Embedding2D(ParallelLayer):
         self.embed_kwargs = kwargs
         self.weight = Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
         self.reset_parameters(weight_initializer)
@@ -758,7 +766,7 @@ class VocabParallelEmbedding2D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
@@ -895,11 +903,18 @@ class Classifier2D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes,
+                    self.input_size_per_partition,
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
+                )
             )
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
@@ -1052,7 +1067,7 @@ class VocabParallelClassifier2D(ParallelLayer):
         self.output_size_per_partition = divide(num_classes, self.summa_dim)
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
...
@@ -5,10 +5,10 @@ import torch.distributed as dist
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication.collective import all_gather, all_reduce, reduce_scatter
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
-from colossalai.utils import get_current_device
 def get_parallel_group(parallel_mode: ParallelMode):
@@ -205,7 +205,7 @@ class Matmul_AB_2p5D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
-        C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.zeros(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -362,7 +362,7 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[0])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -527,7 +527,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
         B_shape = B.shape
         B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[-1], B.shape[-1])
-        C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
+        C = torch.empty(C_shape, dtype=A.dtype, device=get_accelerator().get_current_device())
         # use circular buffer to store the communication tensor
         # 2 is enough for all cases
@@ -661,7 +661,9 @@ class _Add_Bias_2p5D(torch.autograd.Function):
         if row_rank == 0:
             bias_temp = bias.clone()
         else:
-            bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
+            bias_temp = torch.zeros(
+                output_size_per_partition, dtype=bias.dtype, device=get_accelerator().get_current_device()
+            )
         src_rank = (
             col_rank
             + dep_rank * tesseract_dim**2
@@ -984,7 +986,7 @@ class SplitFirst(torch.autograd.Function):
     @custom_bwd
     def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
         grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
-        grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
+        grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_accelerator().get_current_device())
         dist.all_gather(
             list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode)
         )
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import broadcast
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.core import global_context as gpc
@@ -19,7 +20,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
@@ -84,7 +84,7 @@ class Linear2p5D(ParallelLayer):
         self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(
             torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
         )
@@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
         self.partitioned_partition = divide(normalized_shape, self.tesseract_dim)  # *
         # create parameters
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
         if bias:
@@ -451,18 +451,24 @@ class PatchEmbedding2p5D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.embed_size_per_partition, in_chans, *self.patch_size),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
-        self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
+        self.bias = Parameter(
+            torch.empty(self.embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+        )
         self.cls_token = Parameter(
-            torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
+            torch.zeros(
+                (1, 1, self.embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
         self.pos_embed = Parameter(
             torch.zeros(
-                (1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
+                (1, self.num_patches + 1, self.embed_size_per_partition),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
             )
         )
@@ -632,7 +638,9 @@ class Embedding2p5D(ParallelLayer):
         self.embed_kwargs = kwargs
         self.weight = Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
        )
         self.reset_parameters(weight_initializer)
@@ -772,7 +780,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
@@ -910,11 +918,18 @@ class Classifier2p5D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes,
+                    self.input_size_per_partition,
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
+                )
            )
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
@@ -1068,7 +1083,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
         self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
         # create weight, shape: [k/q, h/q]
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         if weight is not None:
             self.weight = weight
             self.has_weight = False
...
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import Parameter
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import all_reduce, broadcast
 from colossalai.legacy.constants import (
     INPUT_GROUP_3D,
@@ -27,7 +28,6 @@ from colossalai.legacy.utils.checkpointing import (
     partition_tensor_parallel_state_dict,
 )
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
 from ._operation import (
@@ -69,11 +69,13 @@ class LayerNorm3D(ParallelLayer):
         self.normalized_shape_per_partition = divide(normalized_shape, self.depth)
         self.weight = Parameter(
-            torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+            torch.ones(self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
         )
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    self.normalized_shape_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
         else:
             self.bias = None
@@ -202,13 +204,15 @@ class Linear3D(ParallelLayer):
             torch.empty(
                 self.in_features_per_partition,
                 self.out_features_per_partition,
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
         else:
             self.bias = None
@@ -380,11 +384,18 @@ class Classifier3D(ParallelLayer):
             self.has_weight = False
         else:
             self.weight = Parameter(
-                torch.empty(self.num_classes, self.in_features_per_partition, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes,
+                    self.in_features_per_partition,
+                    device=get_accelerator().get_current_device(),
+                    dtype=dtype,
+                )
            )
             self.has_weight = True
         if bias:
-            self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
@@ -523,14 +534,16 @@ class VocabParallelClassifier3D(ParallelLayer):
             torch.empty(
                 self.out_features_per_partition,
                 self.in_features_per_partition,
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
         self.has_weight = True
         if bias:
             self.bias = Parameter(
-                torch.zeros(self.bias_features_per_partition, device=get_current_device(), dtype=dtype)
+                torch.zeros(
+                    self.bias_features_per_partition, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
         else:
             self.bias = None
@@ -705,16 +718,24 @@ class PatchEmbedding3D(ParallelLayer):
         self.weight = nn.Parameter(
             torch.empty(
-                (embed_size_per_partition, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype
+                (embed_size_per_partition, in_chans, *self.patch_size),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
             )
         )
-        self.bias = nn.Parameter(torch.empty(embed_size_per_partition, device=get_current_device(), dtype=dtype))
+        self.bias = nn.Parameter(
+            torch.empty(embed_size_per_partition, device=get_accelerator().get_current_device(), dtype=dtype)
+        )
         self.cls_token = nn.Parameter(
-            torch.zeros((1, 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+            torch.zeros((1, 1, embed_size_per_partition), device=get_accelerator().get_current_device(), dtype=dtype)
         )
         self.pos_embed = nn.Parameter(
-            torch.zeros((1, self.num_patches + 1, embed_size_per_partition), device=get_current_device(), dtype=dtype)
+            torch.zeros(
+                (1, self.num_patches + 1, embed_size_per_partition),
+                device=get_accelerator().get_current_device(),
+                dtype=dtype,
+            )
         )
         self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@@ -880,7 +901,9 @@ class Embedding3D(ParallelLayer):
         self.embed_kwargs = kwargs
         self.weight = nn.Parameter(
-            torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (num_embeddings, embed_dim_per_partition), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
         self.reset_parameters(weight_initializer)
@@ -1019,7 +1042,7 @@ class VocabParallelEmbedding3D(ParallelLayer):
         self.weight = Parameter(
             torch.empty(
                 (self.num_embeddings_per_partition, self.embed_dim_per_partition),
-                device=get_current_device(),
+                device=get_accelerator().get_current_device(),
                 dtype=dtype,
             )
         )
...