Unverified commit d202cc28 authored by Hongxin Liu, committed by GitHub

[npu] change device to accelerator api (#5239)



* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
parent dd2c28a3
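
In short, this commit replaces the CUDA-centric `colossalai.utils.get_current_device()` helper with the backend-agnostic accelerator API so the same code path can run on CUDA and NPU backends. A minimal sketch of the recurring pattern in the hunks below (illustrative only; the exact semantics are defined by `colossalai.accelerator`):

```python
import torch

# Before: CUDA-oriented helper
# from colossalai.utils import get_current_device
# buf = torch.zeros(1, device=get_current_device())

# After: resolve the active backend (CUDA, NPU, ...) through the accelerator API
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()  # a torch.device for the active backend
buf = torch.zeros(1, device=device)
```
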
......@@ -5,11 +5,11 @@ import torch
from torch import distributed as dist
from torch.cuda.amp import custom_bwd, custom_fwd
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import ring_forward
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
from colossalai.utils import get_current_device
class RingQK(torch.autograd.Function):
......@@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function):
sub_seq_length,
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
dtype=sub_q.dtype,
device=get_current_device(),
device=get_accelerator().get_current_device(),
)
# compute local QK^T
......@@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function):
grad_q = torch.zeros_like(
sub_q,
dtype=sub_q.dtype,
device=get_current_device(),
device=get_accelerator().get_current_device(),
)
# compute with local sub_k
......@@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function):
batch_size * num_attention_heads,
sub_seq_length,
attention_head_size,
device=get_current_device(),
device=get_accelerator().get_current_device(),
dtype=attention_score.dtype,
)
......@@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function):
grad_v /= local_world_size
# calculate gradient for attention score
grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())
grad_attention_score = torch.zeros_like(
attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device()
)
# compute with local sub_k
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
......
......@@ -7,10 +7,10 @@ from torch import Tensor
from torch import nn as nn
from torch.nn.parameter import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import seed
from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
from colossalai.utils.device import get_current_device
from ..utils import to_2tuple
......@@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module):
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)
torch.empty(
(embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(
torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype)
)
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)
torch.zeros(
(1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
......@@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module):
self.has_weight = False
else:
self.weight = nn.Parameter(
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)
torch.empty(
self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype
)
)
self.has_weight = True
if bias:
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(
torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
)
else:
self.bias = None
......@@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module):
self.normalized_shape = (normalized_shape,)
self.variance_epsilon = eps
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
if bias:
......@@ -333,7 +343,7 @@ class VanillaLinear(nn.Module):
self.in_features = in_features
self.out_features = out_features
self.skip_bias_add = skip_bias_add
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
......
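
A usage sketch of the `factory_kwargs` pattern in the hunk above (the `TinyLinear` name is hypothetical; only `get_accelerator().get_current_device()` is taken from this diff): parameters are allocated directly on whichever device backend is active, so the module needs no explicit `.cuda()` or `.npu()` call afterwards.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator


class TinyLinear(nn.Module):
    # Hypothetical example module: allocates parameters on the current accelerator device at construction time.
    def __init__(self, in_features: int, out_features: int, dtype: torch.dtype = None):
        super().__init__()
        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
        self.weight = nn.Parameter(torch.empty(out_features, in_features, **factory_kwargs))
        self.bias = nn.Parameter(torch.zeros(out_features, **factory_kwargs))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)
```
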
......@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device
@LOSSES.register_module
......@@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients.
......
......@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device
@LOSSES.register_module
......@@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
# Finally elementwise multiplication with the output gradients.
......
......@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
from colossalai.accelerator import get_accelerator
from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device
@LOSSES.register_module
......@@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
predicted_logits = logits[arange_1d, masked_target]
predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
predicted_logits[target_mask] = 0.0
......@@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
grad_2d = input_grad.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
input_grad.mul_(output_grad.unsqueeze(dim=-1))
......
......@@ -7,12 +7,12 @@ from typing import Callable
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import all_reduce
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.utils import is_no_pp_or_last_stage
from colossalai.utils import get_current_device
from ._base_hook import BaseHook
from ._commons_ import _format_number
......@@ -82,8 +82,8 @@ class LossMetric(Metric):
def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device())
self.accum_loss = torch.zeros(1, device=get_current_device())
self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
self.count = 0
def reset(self) -> None:
......@@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
def __init__(self, epoch_only: bool, accuracy_func: Callable):
super().__init__(epoch_only=epoch_only)
self.acc = accuracy_func
self.last_step_sum = torch.zeros(1, device=get_current_device())
self.last_step_correct = torch.zeros(1, device=get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_current_device())
self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())
def reset(self) -> None:
self.last_step_sum.zero_()
......@@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
super().__init__(epoch_only=epoch_only)
self.ignored_steps = ignored_steps
self.cur_steps = 0
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device())
self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
self._tflop_per_step = tflop_per_step
self._use_local = use_local
......
......@@ -6,8 +6,8 @@ import weakref
import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
from colossalai.utils.device import autocast, get_current_device
def copy_to_device(obj, device):
......@@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function):
check_backward_validity(args)
ctx.run_function = run_function
ctx.activation_offload = activation_offload
ctx.device = get_current_device()
ctx.device = get_accelerator().get_current_device()
# preserve rng states
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
inputs[idx] = tensors[i]
detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd:
with torch.enable_grad(), autocast():
with torch.enable_grad(), get_accelerator().autocast()():
outputs = ctx.run_function(*detached_inputs)
else:
with torch.enable_grad():
......@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks(
inner_pack, inner_unpack
):
_unused = function(*args)
......@@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# get device if we need to offload the activation
if activation_offload:
device = get_current_device()
device = get_accelerator().get_current_device()
# run function with pack and unpack as saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
......
......@@ -6,9 +6,9 @@ import torch
import torch.distributed as dist
from packaging import version
from colossalai.accelerator import get_accelerator
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
_GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1
......@@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
if device.type == "cuda":
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
return (
torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
* _GLOBAL_CUDA_MEM_FRACTION
)
def colo_device_memory_used(device: torch.device) -> int:
......@@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
return
global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device())
def colo_set_cpu_memory_capacity(size: int) -> None:
......
......@@ -8,7 +8,7 @@ import torch.distributed as dist
from torch.autograd.profiler import profile
from torch.distributed import ReduceOp
from colossalai.utils import get_current_device
from colossalai.accelerator import get_accelerator
from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
......@@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler):
assert current_comm_event is not None, "dist op has not been found"
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device())
torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
current_comm_event.self_cuda_time = buffer.item()
......
......@@ -3,7 +3,7 @@ import types
from time import time
from typing import List
from colossalai.utils.device import get_current_device
from colossalai.accelerator import get_accelerator
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
......@@ -69,7 +69,7 @@ class StatefulTensorMgr(object):
# move COMPUTE tensors to CUDA
self._cpu_gpu_move_volume += cuda_demand
for t in move_to_cuda_tensor_list:
colo_model_data_tensor_move_inline(t, get_current_device())
colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device())
@property
def cpu_gpu_move_volume(self):
......
......@@ -5,8 +5,8 @@ from typing import List, Optional, Type
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor
......@@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
class CUDATensorPlacementPolicy(TensorPlacementPolicy):
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector)
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
return 0, 0
......@@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
int: the volume of memory that is evicted
"""
start = time()
cuda_capacity = colo_device_memory_capacity(get_current_device())
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
......
......@@ -4,8 +4,8 @@ import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
from colossalai.accelerator import get_accelerator
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
from .tensor_shard_strategy import TensorShardStrategy
......@@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy):
rank = dist.get_rank(process_group)
for i in range(world_size):
if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
buffer_list.append(
flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device())
)
else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device()))
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
# Move to target device before splitting buffer
# Ensure we utilize maximum PCIE bandwidth
......
......@@ -3,11 +3,11 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.legacy.zero.shard_utils.commons import get_shard
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
class TensorShardStrategy(BaseShardStrategy):
......@@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy):
if t.is_sharded:
return
if t.payload.device.type == "cuda":
assert t.payload.device == get_current_device(), (
assert t.payload.device == get_accelerator().get_current_device(), (
f"shard tensor on cuda device index {t.payload.device.index},"
f" but current cuda device is {get_current_device()}"
f" but current cuda device is {get_accelerator().get_current_device()}"
)
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.payload_reset(sharded_payload)
......@@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy):
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
buffer = torch.empty(
payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device()
)
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
buffer_list[rank].copy_(t.payload)
......
......@@ -10,6 +10,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils.memory import colo_device_memory_capacity
......@@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.logging import get_dist_logger
from colossalai.utils import disposable, get_current_device
from colossalai.utils import disposable
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from ._utils import (
......@@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module):
self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, "w+") as f:
f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
f.write(
f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n"
)
f.write(
f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n"
)
f.write("CUDA model data (GB)\n")
f.write("\n")
f.write("CUDA non model data (GB)\n")
......@@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module):
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = (
colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
colo_device_memory_capacity(get_accelerator().get_current_device())
- self._memstats_collector._memstats.max_overall_cuda
)
@torch.no_grad()
......
......@@ -3,13 +3,13 @@ from typing import Optional
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.registry import OPHOOKS
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
......@@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook):
self.process_group = process_group
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = get_current_device()
self.computing_device = get_accelerator().get_current_device()
self._memstarts_collector = memstarts_collector
self._stateful_tensor_mgr = stateful_tensor_mgr
......
......@@ -8,9 +8,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER
from colossalai.utils import get_current_device
class MoeRouter(nn.Module, ABC):
......@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False):
def __init__(
self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False,
):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
......@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert router_probs.dim() == expert_indices.dim() == 3, \
"router_probs must be 3D tensor and expert_indices must be 4D tensor"
assert (
router_probs.dim() == expert_indices.dim() == 3
), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
......@@ -122,25 +125,29 @@ class Top1Router(MoeRouter):
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device())
low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
......@@ -216,18 +223,22 @@ class Top2Router(MoeRouter):
drop_tks (bool, optional): Whether drops tokens in evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
"""
......@@ -255,8 +266,8 @@ class Top2Router(MoeRouter):
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
cmask = mask1 + mask2 # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
......@@ -269,7 +280,7 @@ class Top2Router(MoeRouter):
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
......@@ -336,15 +347,18 @@ class TopKRouter(MoeRouter):
oversubscribed / reach capacity.
"""
def __init__(self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
drop_tks)
def __init__(
self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
)
def forward(
self,
......@@ -410,7 +424,7 @@ class TopKRouter(MoeRouter):
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
return combine_array, dispatch_mask
......
......@@ -7,13 +7,12 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
from colossalai.utils import get_current_device
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data.clone()
......@@ -30,8 +29,8 @@ class NormalNoiseGenerator:
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
......@@ -52,8 +51,8 @@ class UniformNoiseGenerator:
def __init__(self, eps: float = 1e-2):
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps, device=get_current_device()),
low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
......@@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
epsize_param_dict = dict()
for param in model.parameters():
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = get_ep_size(param)
if ep_size not in epsize_param_dict:
......@@ -193,18 +192,13 @@ def create_ep_hierarchical_group(
assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
nproc_per_node = int(nproc_per_node)
else:
assert dist.get_world_size() % nproc_per_node == 0, \
"nproc_per_node should be a divisor of world_size."
assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
num_node = dist.get_world_size() // nproc_per_node
intra_src_rank = None
ep_intra_node_group = None
for i in range(num_node):
ep_intra_ranks = [
i * nproc_per_node + j
for j in range(nproc_per_node)
if j in ep_group_ranks
]
ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
group = dist.new_group(ep_intra_ranks)
if rank in ep_intra_ranks:
assert ep_intra_node_group is None
......@@ -212,10 +206,7 @@ def create_ep_hierarchical_group(
intra_src_rank = ep_intra_ranks[0]
ep_inter_node_group = None
ep_inter_ranks = [
ep_group_ranks[0] + i * nproc_per_node
for i in range(num_node)
]
ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
if len(ep_inter_ranks) > 1:
group = dist.new_group(ep_inter_ranks)
if rank in ep_inter_ranks:
......
......@@ -7,10 +7,10 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule
......@@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule):
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def _prepare_inputs_for_interval_stage(self):
"""
......
......@@ -6,10 +6,10 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
......@@ -56,7 +56,7 @@ class InterleavedSchedule(PipelineSchedule):
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
......@@ -292,7 +292,7 @@ class InterleavedSchedule(PipelineSchedule):
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
......
......@@ -6,10 +6,10 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.device import get_current_device
from ._utils import (
detach,
......@@ -80,7 +80,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
def recv_forward(self, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
......@@ -297,7 +297,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
else:
accum_loss = None
......