Unverified Commit d202cc28 authored by Hongxin Liu, committed by GitHub

[npu] change device to accelerator api (#5239)



* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------
Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
parent dd2c28a3
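Editor's note: the hunks below apply one mechanical substitution across the codebase. A minimal sketch of the before/after pattern, with both module paths taken verbatim from the diff (backend resolution to CUDA/NPU/CPU is assumed from the commit title):

```python
import torch

# Old, CUDA-centric helper (removed by this commit):
#   from colossalai.utils import get_current_device
#   device = get_current_device()

# New, backend-agnostic accelerator API; get_accelerator() is assumed to
# resolve to the CUDA, NPU, or CPU accelerator detected at runtime.
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
buffer = torch.zeros(1, device=device)  # the same pattern the hunks below use
```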
@@ -5,11 +5,11 @@ import torch
 from torch import distributed as dist
 from torch.cuda.amp import custom_bwd, custom_fwd
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import ring_forward
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_sequence._utils import _calc_current_device_range, _calc_incoming_device_range
-from colossalai.utils import get_current_device
 class RingQK(torch.autograd.Function):
@@ -30,7 +30,7 @@ class RingQK(torch.autograd.Function):
             sub_seq_length,
             sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
             dtype=sub_q.dtype,
-            device=get_current_device(),
+            device=get_accelerator().get_current_device(),
         )
         # compute local QK^T
@@ -71,7 +71,7 @@ class RingQK(torch.autograd.Function):
         grad_q = torch.zeros_like(
             sub_q,
             dtype=sub_q.dtype,
-            device=get_current_device(),
+            device=get_accelerator().get_current_device(),
         )
         # compute with local sub_k
@@ -105,7 +105,7 @@ class RingAV(torch.autograd.Function):
             batch_size * num_attention_heads,
             sub_seq_length,
             attention_head_size,
-            device=get_current_device(),
+            device=get_accelerator().get_current_device(),
             dtype=attention_score.dtype,
         )
@@ -142,7 +142,9 @@ class RingAV(torch.autograd.Function):
         grad_v /= local_world_size
         # calculate gradient for attention score
-        grad_attention_score = torch.zeros_like(attention_scores, dtype=grad_output.dtype, device=get_current_device())
+        grad_attention_score = torch.zeros_like(
+            attention_scores, dtype=grad_output.dtype, device=get_accelerator().get_current_device()
+        )
         # compute with local sub_k
         grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(grad_output, sub_v.transpose(2, 1))
...
@@ -7,10 +7,10 @@ from torch import Tensor
 from torch import nn as nn
 from torch.nn.parameter import Parameter
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import seed
 from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.utils.device import get_current_device
 from ..utils import to_2tuple
@@ -173,12 +173,18 @@ class VanillaPatchEmbedding(nn.Module):
         self.flatten = flatten
         self.weight = nn.Parameter(
-            torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)
+            torch.empty(
+                (embed_size, in_chans, *self.patch_size), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
-        self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
-        self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
+        self.bias = nn.Parameter(torch.empty(embed_size, device=get_accelerator().get_current_device(), dtype=dtype))
+        self.cls_token = nn.Parameter(
+            torch.zeros((1, 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype)
+        )
         self.pos_embed = nn.Parameter(
-            torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype)
+            torch.zeros(
+                (1, self.num_patches + 1, embed_size), device=get_accelerator().get_current_device(), dtype=dtype
+            )
         )
         self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
@@ -242,11 +248,15 @@ class VanillaClassifier(nn.Module):
             self.has_weight = False
         else:
             self.weight = nn.Parameter(
-                torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype)
+                torch.empty(
+                    self.num_classes, self.in_features, device=get_accelerator().get_current_device(), dtype=dtype
+                )
             )
             self.has_weight = True
         if bias:
-            self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
+            self.bias = nn.Parameter(
+                torch.zeros(self.num_classes, device=get_accelerator().get_current_device(), dtype=dtype)
+            )
         else:
             self.bias = None
@@ -287,7 +297,7 @@ class VanillaLayerNorm(nn.Module):
         self.normalized_shape = (normalized_shape,)
         self.variance_epsilon = eps
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = nn.Parameter(torch.ones(normalized_shape, **factory_kwargs))
         if bias:
@@ -333,7 +343,7 @@ class VanillaLinear(nn.Module):
         self.in_features = in_features
         self.out_features = out_features
         self.skip_bias_add = skip_bias_add
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
         if bias:
             self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
...
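Editor's note: the `factory_kwargs` hunks above show the idiom for allocating parameters directly on the accelerator. A minimal sketch of a module built the same way (the class name is illustrative, not part of the commit):

```python
import torch
import torch.nn as nn

from colossalai.accelerator import get_accelerator


class TinyAffine(nn.Module):
    """Illustrative module following the factory_kwargs idiom above."""

    def __init__(self, dim: int, dtype: torch.dtype = None):
        super().__init__()
        # Allocate parameters on the current accelerator device up front,
        # avoiding a CPU round-trip and a hard-coded .cuda() call.
        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
        self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
        self.bias = nn.Parameter(torch.zeros(dim, **factory_kwargs))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight + self.bias
```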
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
 from colossalai.legacy.nn.layer.parallel_2d._utils import assert_summa_initialization
 from colossalai.legacy.registry import LOSSES
-from colossalai.utils import get_current_device
 @LOSSES.register_module
@@ -118,7 +118,7 @@ class _VocabParallelCrossEntropy2D(torch.autograd.Function):
         grad_2d = grad_input.view(-1, partition_vocab_size)
         # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
         grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
         # Finally elementwise multiplication with the output gradients.
...
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
 from colossalai.legacy.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
 from colossalai.legacy.registry import LOSSES
-from colossalai.utils import get_current_device
 @LOSSES.register_module
@@ -112,7 +112,7 @@ class _VocabParallelCrossEntropy2p5D(torch.autograd.Function):
         grad_2d = grad_input.view(-1, partition_vocab_size)
         # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
         grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
         # Finally elementwise multiplication with the output gradients.
...
@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
 from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.legacy.registry import LOSSES
-from colossalai.utils import get_current_device
 @LOSSES.register_module
@@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
         target_mask = (targets < vocab_start) | (targets > vocab_end)
         masked_target = targets.clone() - vocab_start
         masked_target[target_mask] = 0
-        arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
         predicted_logits = logits[arange_1d, masked_target]
         predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
         predicted_logits[target_mask] = 0.0
@@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
         grad_2d = input_grad.view(-1, partition_vocab_size)
         # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
         grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
         input_grad.mul_(output_grad.unsqueeze(dim=-1))
...
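Editor's note: the changed `arange_1d` lines in these vocab-parallel losses all serve one gather: picking each token's logit for its shifted, masked target class. A hedged sketch of that step in isolation, using the names from the diff:

```python
import torch

from colossalai.accelerator import get_accelerator


def gather_target_logits(logits: torch.Tensor, masked_target: torch.Tensor) -> torch.Tensor:
    # One row index per token, allocated on the same device as the logits
    # so the advanced indexing below stays on-device.
    arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
    return logits[arange_1d, masked_target]
```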
@@ -7,12 +7,12 @@ from typing import Callable
 import torch
 import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import all_reduce
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.registry import HOOKS
 from colossalai.legacy.utils import is_no_pp_or_last_stage
-from colossalai.utils import get_current_device
 from ._base_hook import BaseHook
 from ._commons_ import _format_number
@@ -82,8 +82,8 @@ class LossMetric(Metric):
     def __init__(self, epoch_only):
         super().__init__(epoch_only=epoch_only)
-        self.last_step_loss = torch.zeros(1, device=get_current_device())
-        self.accum_loss = torch.zeros(1, device=get_current_device())
+        self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         self.count = 0
     def reset(self) -> None:
@@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
     def __init__(self, epoch_only: bool, accuracy_func: Callable):
         super().__init__(epoch_only=epoch_only)
         self.acc = accuracy_func
-        self.last_step_sum = torch.zeros(1, device=get_current_device())
-        self.last_step_correct = torch.zeros(1, device=get_current_device())
-        self.accumulated_sum = torch.zeros(1, device=get_current_device())
-        self.accumulated_correct = torch.zeros(1, device=get_current_device())
+        self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())
     def reset(self) -> None:
         self.last_step_sum.zero_()
@@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
         super().__init__(epoch_only=epoch_only)
         self.ignored_steps = ignored_steps
         self.cur_steps = 0
-        self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
-        self.accumulated_used_time = torch.zeros(1, device=get_current_device())
-        self.last_step_num_samples = torch.zeros(1, device=get_current_device())
-        self.last_step_used_time = torch.zeros(1, device=get_current_device())
+        self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
         self._tflop_per_step = tflop_per_step
         self._use_local = use_local
...
@@ -6,8 +6,8 @@ import weakref
 import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils.device import autocast, get_current_device
 def copy_to_device(obj, device):
@@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function):
         check_backward_validity(args)
         ctx.run_function = run_function
         ctx.activation_offload = activation_offload
-        ctx.device = get_current_device()
+        ctx.device = get_accelerator().get_current_device()
         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()
@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
             inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), autocast():
+            with torch.enable_grad(), get_accelerator().autocast()():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():
@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
             # rerun forward, the inner_pack will store all the activations in storage
             if has_autocast_in_fwd:
-                with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
+                with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks(
                     inner_pack, inner_unpack
                 ):
                     _unused = function(*args)
@@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
     # get device if we need to offload the activation
     if activation_offload:
-        device = get_current_device()
+        device = get_accelerator().get_current_device()
     # run function with pack and unpack as saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
...
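Editor's note: the double call `get_accelerator().autocast()()` in the checkpoint hunks is deliberate. Reading from the diff, `autocast()` evidently returns the backend's autocast context-manager factory (e.g. `torch.cuda.amp.autocast` on CUDA), which then has to be invoked itself. A minimal sketch under that assumption:

```python
import torch

from colossalai.accelerator import get_accelerator


def recompute(run_function, *detached_inputs):
    # autocast() returns the factory; calling the factory yields the actual
    # context manager, hence the back-to-back parentheses in the diff.
    with torch.enable_grad(), get_accelerator().autocast()():
        return run_function(*detached_inputs)
```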
@@ -6,9 +6,9 @@ import torch
 import torch.distributed as dist
 from packaging import version
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 _GLOBAL_CPU_MEM_CAPACITY = -1
@@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
         # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
         return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
     if device.type == "cuda":
-        return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
+        return (
+            torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+            * _GLOBAL_CUDA_MEM_FRACTION
+        )
 def colo_device_memory_used(device: torch.device) -> int:
@@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
         return
     global _GLOBAL_CUDA_MEM_FRACTION
     _GLOBAL_CUDA_MEM_FRACTION = ratio
-    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
+    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device())
 def colo_set_cpu_memory_capacity(size: int) -> None:
...
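Editor's note: a small usage sketch of the two capacity helpers changed above. The torch calls are as they appear in the diff; the 0.8 ratio is an arbitrary example value:

```python
import torch

from colossalai.accelerator import get_accelerator

# Cap this process at 80% of the device's memory, then compute the
# effective capacity the same way colo_device_memory_capacity does.
fraction = 0.8
device = get_accelerator().get_current_device()
torch.cuda.set_per_process_memory_fraction(fraction, device)
capacity = torch.cuda.get_device_properties(device).total_memory * fraction
print(f"usable capacity: {capacity / 1e9:.1f} GB")
```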
@@ -8,7 +8,7 @@ import torch.distributed as dist
 from torch.autograd.profiler import profile
 from torch.distributed import ReduceOp
-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator
 from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
@@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler):
         assert current_comm_event is not None, "dist op has not been found"
-        buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
+        buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device())
         torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
         current_comm_event.self_cuda_time = buffer.item()
...
@@ -3,7 +3,7 @@ import types
 from time import time
 from typing import List
-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator
 from .stateful_tensor import StatefulTensor, TensorState
 from .tensor_placement_policy import TensorPlacementPolicy
@@ -69,7 +69,7 @@ class StatefulTensorMgr(object):
         # move COMPUTE tensors to CUDA
         self._cpu_gpu_move_volume += cuda_demand
         for t in move_to_cuda_tensor_list:
-            colo_model_data_tensor_move_inline(t, get_current_device())
+            colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device())
     @property
     def cpu_gpu_move_volume(self):
...
@@ -5,8 +5,8 @@ from typing import List, Optional, Type
 import torch
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector
 from .stateful_tensor import StatefulTensor
@@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
 class CUDATensorPlacementPolicy(TensorPlacementPolicy):
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
-        super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
+        super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector)
     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
         return 0, 0
@@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             int: the volume of memory that is evicted
         """
         start = time()
-        cuda_capacity = colo_device_memory_capacity(get_current_device())
+        cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
...
@@ -4,8 +4,8 @@ import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors as flatten
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils import get_current_device
 from .tensor_shard_strategy import TensorShardStrategy
@@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy):
         rank = dist.get_rank(process_group)
         for i in range(world_size):
             if i == rank:
-                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
+                buffer_list.append(
+                    flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device())
+                )
             else:
-                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
+                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device()))
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
         # Move to target device before splitting buffer
         # Ensure we utilize maximum PCIE bandwidth
...
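Editor's note: the bucketing change above keeps every gather buffer on the current accelerator device. A standalone sketch of the same all-gather pattern (function name and shapes are illustrative; `.to(device)` stands in for the diff's `.cuda(...)` to stay backend-neutral):

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten

from colossalai.accelerator import get_accelerator


def gather_flat_payloads(payloads, buffer_size, dtype, process_group=None):
    # Each rank contributes one flat buffer; peers are received into
    # zero-filled buffers of the same size on the current device.
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    device = get_accelerator().get_current_device()
    buffer_list = [
        flatten(payloads).to(device) if i == rank
        else torch.zeros(buffer_size, dtype=dtype, device=device)
        for i in range(world_size)
    ]
    dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
    return buffer_list
```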
@@ -3,11 +3,11 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.legacy.zero.shard_utils.commons import get_shard
 from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils import get_current_device
 class TensorShardStrategy(BaseShardStrategy):
@@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy):
         if t.is_sharded:
             return
         if t.payload.device.type == "cuda":
-            assert t.payload.device == get_current_device(), (
+            assert t.payload.device == get_accelerator().get_current_device(), (
                 f"shard tensor on cuda device index {t.payload.device.index},"
-                f" but current cuda device is {get_current_device()}"
+                f" but current cuda device is {get_accelerator().get_current_device()}"
             )
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.payload_reset(sharded_payload)
@@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy):
         world_size = dist.get_world_size(process_group)
         rank = dist.get_rank(process_group)
-        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
+        buffer = torch.empty(
+            payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device()
+        )
         buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
         buffer_list[rank].copy_(t.payload)
...
@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
@@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_cpu
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.logging import get_dist_logger
-from colossalai.utils import disposable, get_current_device
+from colossalai.utils import disposable
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector
 from ._utils import (
@@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module):
             self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
             if gpc.get_global_rank() == 0:
                 with open(filename, "w+") as f:
-                    f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
-                    f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
+                    f.write(
+                        f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n"
+                    )
+                    f.write(
+                        f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n"
+                    )
                     f.write("CUDA model data (GB)\n")
                     f.write("\n")
                     f.write("CUDA non model data (GB)\n")
@@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = (
-                colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
+                colo_device_memory_capacity(get_accelerator().get_current_device())
+                - self._memstats_collector._memstats.max_overall_cuda
             )
     @torch.no_grad()
...
@@ -3,13 +3,13 @@ from typing import Optional
 import torch
 import torch.distributed as dist
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.registry import OPHOOKS
 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
 from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
 from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector
@@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook):
         self.process_group = process_group
         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
-        self.computing_device = get_current_device()
+        self.computing_device = get_accelerator().get_current_device()
         self._memstarts_collector = memstarts_collector
         self._stateful_tensor_mgr = stateful_tensor_mgr
...
@@ -8,9 +8,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup
+from colossalai.accelerator import get_accelerator
 from colossalai.moe._operation import moe_cumsum
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.utils import get_current_device
 class MoeRouter(nn.Module, ABC):
@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """
-    def __init__(self,
-                 k_value: int,
-                 capacity_factor_train: float,
-                 capacity_factor_eval: float,
-                 min_capacity: int,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True,
-                 use_kernel: bool = False):
+    def __init__(
+        self,
+        k_value: int,
+        capacity_factor_train: float,
+        capacity_factor_eval: float,
+        min_capacity: int,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+        use_kernel: bool = False,
+    ):
         super().__init__()
         self.k_value = k_value
         self.capacity_factor_train = capacity_factor_train
@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
         if router_probs.dim() == expert_indices.dim() == 2:
             router_probs = router_probs.unsqueeze(0)
             expert_indices = expert_indices.unsqueeze(0)
-        assert router_probs.dim() == expert_indices.dim() == 3, \
-            "router_probs must be 3D tensor and expert_indices must be 4D tensor"
+        assert (
+            router_probs.dim() == expert_indices.dim() == 3
+        ), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
         # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
         expert_mask = F.one_hot(expert_indices, num_experts)
@@ -122,25 +125,29 @@ class Top1Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 select_policy: str = "first",
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=1,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        select_policy: str = "first",
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=1,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
             self.uniform = torch.distributions.uniform.Uniform(
-                low=torch.tensor(0.0, device=get_current_device()),
-                high=torch.tensor(1.0, device=get_current_device())
+                low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
+                high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
             ).rsample
     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
@@ -216,18 +223,22 @@ class Top2Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation.
     """
-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=2,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=2,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         """
@@ -255,8 +266,8 @@ class Top2Router(MoeRouter):
         top2_idx = torch.argmax(logits_except1, dim=-1)
         mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
-        cmask = (mask1 + mask2)  # loss: [s, e]
+        cmask = mask1 + mask2  # loss: [s, e]
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1
         # calculate loss
         expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
@@ -269,7 +280,7 @@ class Top2Router(MoeRouter):
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()
-        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)    # rank1: [s, e]
+        rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
         rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
         rank2 += torch.sum(mask1, dim=-2, keepdim=True)
@@ -336,15 +347,18 @@ class TopKRouter(MoeRouter):
         oversubscribed / reach capacity.
     """
-    def __init__(self,
-                 num_selected_experts: int,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
-                         drop_tks)
+    def __init__(
+        self,
+        num_selected_experts: int,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
+        )
     def forward(
         self,
@@ -410,7 +424,7 @@ class TopKRouter(MoeRouter):
         # The combine array will be used for combining expert outputs, scaled by the
         # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
         # expert_capacity].
-        combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
+        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
         return combine_array, dispatch_mask
...
@@ -7,13 +7,12 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
+from colossalai.accelerator import get_accelerator
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
-from colossalai.utils import get_current_device
 class ForceFP32Parameter(torch.nn.Parameter):
     def half(self, memory_format=None):
         return self.data.clone()
@@ -30,8 +29,8 @@ class NormalNoiseGenerator:
     def __init__(self, num_experts: int):
         self.normal = torch.distributions.normal.Normal(
-            loc=torch.tensor(0.0, device=get_current_device()),
-            scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
+            loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
+            scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
         ).rsample
     def __call__(self, inputs: torch.Tensor):
@@ -52,8 +51,8 @@ class UniformNoiseGenerator:
     def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(
-            low=torch.tensor(1.0 - eps, device=get_current_device()),
-            high=torch.tensor(1.0 + eps, device=get_current_device()),
+            low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
+            high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
         ).rsample
     def __call__(self, inputs: torch.Tensor):
@@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]:
     epsize_param_dict = dict()
     for param in model.parameters():
         if not is_moe_tensor(param):
            ep_size = 1  # set ep_size to 1 for dp parameters
         else:
             ep_size = get_ep_size(param)
         if ep_size not in epsize_param_dict:
@@ -193,18 +192,13 @@ def create_ep_hierarchical_group(
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
         nproc_per_node = int(nproc_per_node)
     else:
-        assert dist.get_world_size() % nproc_per_node == 0, \
-            "nproc_per_node should be a divisor of world_size."
+        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node
     intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
-        ep_intra_ranks = [
-            i * nproc_per_node + j
-            for j in range(nproc_per_node)
-            if j in ep_group_ranks
-        ]
+        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None
@@ -212,10 +206,7 @@ def create_ep_hierarchical_group(
             intra_src_rank = ep_intra_ranks[0]
     ep_inter_node_group = None
-    ep_inter_ranks = [
-        ep_group_ranks[0] + i * nproc_per_node
-        for i in range(num_node)
-    ]
+    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
    if len(ep_inter_ranks) > 1:
         group = dist.new_group(ep_inter_ranks)
         if rank in ep_inter_ranks:
...
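Editor's note: the noise generators above pre-build a distribution on the accelerator device and keep only its bound `.rsample` method. A condensed sketch of how a router would use one (the helper name is illustrative):

```python
import torch

from colossalai.accelerator import get_accelerator


def make_uniform_noise_fn(eps: float = 1e-2):
    # Distribution parameters live on the current accelerator device, so
    # sampled noise lands on-device with no extra transfer.
    device = get_accelerator().get_current_device()
    return torch.distributions.uniform.Uniform(
        low=torch.tensor(1.0 - eps, device=device),
        high=torch.tensor(1.0 + eps, device=device),
    ).rsample


noisy = make_uniform_noise_fn()
logits = torch.randn(8, 4, device=get_accelerator().get_current_device())
noisy_logits = logits * noisy(logits.shape)  # multiplicative jitter, as in UniformNoiseGenerator
```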
@@ -7,10 +7,10 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
+from colossalai.accelerator import get_accelerator
 from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
 from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule
@@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule):
         """
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
     def _prepare_inputs_for_interval_stage(self):
         """
...
@@ -6,10 +6,10 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule
@@ -56,7 +56,7 @@ class InterleavedSchedule(PipelineSchedule):
         """
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
     def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
         """Helper method to get the model chunk ID given the iteration number.
@@ -292,7 +292,7 @@ class InterleavedSchedule(PipelineSchedule):
         outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
         if return_loss and self.stage_manager.is_last_stage():
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
...
@@ -6,10 +6,10 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
+from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
 from ._utils import (
     detach,
@@ -80,7 +80,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         """
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
     def recv_forward(self, prev_rank: int = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
@@ -297,7 +297,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
         if return_loss and self.stage_manager.is_last_stage():
-            accum_loss = torch.zeros(1, device=get_current_device())
+            accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         else:
             accum_loss = None
...
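Editor's note: all three pipeline schedules now stage micro-batches onto the accelerator the same way. A minimal sketch of the `tree_map` pattern, with a stand-in `to_device` since the real helper lives in the schedules' `_utils` module:

```python
from functools import partial

import torch
from torch.utils._pytree import tree_map

from colossalai.accelerator import get_accelerator


def to_device(x, device):
    # Stand-in for the _utils.to_device imported in the diff above.
    return x.to(device) if torch.is_tensor(x) else x


micro_batch = {"input_ids": torch.ones(2, 8, dtype=torch.long), "labels": torch.zeros(2, 8)}
# tree_map walks the (possibly nested) batch structure and moves every
# tensor leaf to the current accelerator device.
micro_batch = tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)
```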