Unverified Commit 8823cc48 authored by Frank Lee, committed by GitHub

Merge pull request #5310 from hpcaitech/feature/npu

Feature/npu
parents bce9499e 73f4dc57
......@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss
from colossalai.accelerator import get_accelerator
from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.legacy.registry import LOSSES
from colossalai.utils import get_current_device
@LOSSES.register_module
......@@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
predicted_logits = logits[arange_1d, masked_target]
predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
predicted_logits[target_mask] = 0.0
......@@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
grad_2d = input_grad.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
input_grad.mul_(output_grad.unsqueeze(dim=-1))
......
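The hunk above is the shard-local half of vocab-parallel cross entropy: targets outside this rank's vocabulary slice are masked out, and an index gather picks each token's logit for its shifted target class. A minimal standalone sketch of that gather in plain PyTorch (hypothetical sizes, no parallelism):

import torch

# Hypothetical shard: 4 tokens, local vocab slice covering global ids [100, 150)
logits = torch.randn(4, 50)                      # [tokens, partition_vocab_size]
targets = torch.tensor([102, 7, 149, 400])       # global class ids
vocab_start, vocab_end = 100, 149

target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0                   # out-of-shard targets index a dummy slot

arange_1d = torch.arange(logits.size(0))
predicted_logits = logits[arange_1d, masked_target].clone()
predicted_logits[target_mask] = 0.0              # an all-reduce over the vocab group would follow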
......@@ -7,12 +7,12 @@ from typing import Callable
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.communication import all_reduce
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.registry import HOOKS
from colossalai.legacy.utils import is_no_pp_or_last_stage
from colossalai.utils import get_current_device
from ._base_hook import BaseHook
from ._commons_ import _format_number
......@@ -82,8 +82,8 @@ class LossMetric(Metric):
def __init__(self, epoch_only):
super().__init__(epoch_only=epoch_only)
self.last_step_loss = torch.zeros(1, device=get_current_device())
self.accum_loss = torch.zeros(1, device=get_current_device())
self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
self.count = 0
def reset(self) -> None:
......@@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
def __init__(self, epoch_only: bool, accuracy_func: Callable):
super().__init__(epoch_only=epoch_only)
self.acc = accuracy_func
self.last_step_sum = torch.zeros(1, device=get_current_device())
self.last_step_correct = torch.zeros(1, device=get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_current_device())
self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())
def reset(self) -> None:
self.last_step_sum.zero_()
......@@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
super().__init__(epoch_only=epoch_only)
self.ignored_steps = ignored_steps
self.cur_steps = 0
self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_current_device())
self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
self._tflop_per_step = tflop_per_step
self._use_local = use_local
......
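All of the metric buffers above now ask the active accelerator for the current device instead of assuming CUDA, so the same hooks work on NPU or CPU backends. A minimal sketch of the pattern (assuming a colossalai install where colossalai.accelerator is available):

import torch
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()   # e.g. cuda:0, npu:0, or cpu
last_step_loss = torch.zeros(1, device=device)
accum_loss = torch.zeros(1, device=device)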
......@@ -6,8 +6,8 @@ import weakref
import torch
from torch.utils.checkpoint import check_backward_validity, detach_variable
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
from colossalai.utils.device import autocast, get_current_device
def copy_to_device(obj, device):
......@@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function):
check_backward_validity(args)
ctx.run_function = run_function
ctx.activation_offload = activation_offload
ctx.device = get_current_device()
ctx.device = get_accelerator().get_current_device()
# preserve rng states
ctx.fwd_cpu_rng_state = torch.get_rng_state()
......@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
inputs[idx] = tensors[i]
detached_inputs = detach_variable(tuple(inputs))
if ctx.had_autocast_in_fwd:
with torch.enable_grad(), autocast():
with torch.enable_grad(), get_accelerator().autocast()():
outputs = ctx.run_function(*detached_inputs)
else:
with torch.enable_grad():
......@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# rerun forward, the inner_pack will store all the activations in storage
if has_autocast_in_fwd:
with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks(
inner_pack, inner_unpack
):
_unused = function(*args)
......@@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
# get device if we need to offload the activation
if activation_offload:
device = get_current_device()
device = get_accelerator().get_current_device()
# run function with pack and unpack as saved_tensors_hooks
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
......
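The non-reentrant checkpoint path relies on torch.autograd.graph.saved_tensors_hooks; when activation offload is enabled, the pack hook can park saved activations in host memory and the unpack hook moves them back for the backward pass. A rough standalone illustration of that offload idea in plain PyTorch (not the ColossalAI implementation):

import torch

def pack(tensor):
    # Offload the saved activation to CPU, remembering its original device.
    return tensor.device, tensor.to("cpu")

def unpack(packed):
    device, tensor = packed
    return tensor.to(device)

x = torch.randn(8, 8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = (x @ x).relu().sum()
y.backward()    # activations are pulled back from CPU as they are needed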
......@@ -96,9 +96,9 @@ def _calc_l2_norm(grads):
global fused_optim
if fused_optim is None:
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
norm = 0.0
if len(grads) > 0:
......
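Here, and in the fused optimizers further down, the extension is now obtained through the kernel_loader interface and cached in a module-level global, so the potentially slow build/load only happens on first use. A sketch of that caching pattern (helper name is illustrative; assumes a colossalai install that provides FusedOptimizerLoader):

fused_optim = None

def _load_fused_optim():
    global fused_optim
    if fused_optim is None:
        # Deferred import so merely importing this module never triggers a kernel build.
        from colossalai.kernel.kernel_loader import FusedOptimizerLoader
        fused_optim = FusedOptimizerLoader().load()
    return fused_optim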
......@@ -6,9 +6,9 @@ import torch
import torch.distributed as dist
from packaging import version
from colossalai.accelerator import get_accelerator
from colossalai.legacy.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
_GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1
......@@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N of the overall CPU memory.
return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
if device.type == "cuda":
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
return (
torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
* _GLOBAL_CUDA_MEM_FRACTION
)
def colo_device_memory_used(device: torch.device) -> int:
......@@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
return
global _GLOBAL_CUDA_MEM_FRACTION
_GLOBAL_CUDA_MEM_FRACTION = ratio
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device())
def colo_set_cpu_memory_capacity(size: int) -> None:
......
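colo_set_process_memory_fraction ultimately forwards the ratio to PyTorch's per-process memory cap on the accelerator's current device, and colo_device_memory_capacity scales the reported total memory by the same fraction. The equivalent calls in plain PyTorch on a CUDA machine would be roughly:

import torch

if torch.cuda.is_available():
    device = torch.cuda.current_device()
    fraction = 0.5                                   # hypothetical cap: 50% of device memory
    torch.cuda.set_per_process_memory_fraction(fraction, device)
    capacity = torch.cuda.get_device_properties(device).total_memory * fraction
    print(f"usable capacity: {capacity / 1e9:.2f} GB")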
......@@ -8,7 +8,7 @@ import torch.distributed as dist
from torch.autograd.profiler import profile
from torch.distributed import ReduceOp
from colossalai.utils import get_current_device
from colossalai.accelerator import get_accelerator
from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time
......@@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler):
assert current_comm_event is not None, "dist op has not been found"
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device())
torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
current_comm_event.self_cuda_time = buffer.item()
......
......@@ -3,7 +3,7 @@ import types
from time import time
from typing import List
from colossalai.utils.device import get_current_device
from colossalai.accelerator import get_accelerator
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
......@@ -69,7 +69,7 @@ class StatefulTensorMgr(object):
# move COMPUTE tensors to CUDA
self._cpu_gpu_move_volume += cuda_demand
for t in move_to_cuda_tensor_list:
colo_model_data_tensor_move_inline(t, get_current_device())
colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device())
@property
def cpu_gpu_move_volume(self):
......
......@@ -5,8 +5,8 @@ from typing import List, Optional, Type
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor
......@@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
class CUDATensorPlacementPolicy(TensorPlacementPolicy):
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector)
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
return 0, 0
......@@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
int: the volume of memory that is evicted
"""
start = time()
cuda_capacity = colo_device_memory_capacity(get_current_device())
cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
if warmup:
# We designate a part of CUDA memory for model data in warmup iterations.
......
......@@ -4,8 +4,8 @@ import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
from colossalai.accelerator import get_accelerator
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
from .tensor_shard_strategy import TensorShardStrategy
......@@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy):
rank = dist.get_rank(process_group)
for i in range(world_size):
if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
buffer_list.append(
flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device())
)
else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device()))
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
# Move to target device before splitting buffer
# Ensure we utilize maximum PCIE bandwidth
......
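The bucket strategy above flattens every shard payload owned by a rank into one contiguous buffer before a single all_gather, so many small payloads travel in one collective. A simplified sketch of that gather (assuming torch.distributed is already initialized and every rank passes same-sized tensors):

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten

def gather_flat_shards(tensor_list, group=None):
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    flat = flatten(tensor_list).contiguous()
    buffers = [torch.empty_like(flat) for _ in range(world_size)]
    buffers[rank] = flat
    dist.all_gather(buffers, flat, group=group)
    return buffers      # buffers[i] holds rank i's flattened shards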
......@@ -3,11 +3,11 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.legacy.zero.shard_utils.commons import get_shard
from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
class TensorShardStrategy(BaseShardStrategy):
......@@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy):
if t.is_sharded:
return
if t.payload.device.type == "cuda":
assert t.payload.device == get_current_device(), (
assert t.payload.device == get_accelerator().get_current_device(), (
f"shard tensor on cuda device index {t.payload.device.index},"
f" but current cuda device is {get_current_device()}"
f" but current cuda device is {get_accelerator().get_current_device()}"
)
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.payload_reset(sharded_payload)
......@@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy):
world_size = dist.get_world_size(process_group)
rank = dist.get_rank(process_group)
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
buffer = torch.empty(
payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device()
)
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
buffer_list[rank].copy_(t.payload)
......
......@@ -10,6 +10,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils.memory import colo_device_memory_capacity
......@@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.logging import get_dist_logger
from colossalai.utils import disposable, get_current_device
from colossalai.utils import disposable
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from ._utils import (
......@@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module):
self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
if gpc.get_global_rank() == 0:
with open(filename, "w+") as f:
f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
f.write(
f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n"
)
f.write(
f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n"
)
f.write("CUDA model data (GB)\n")
f.write("\n")
f.write("CUDA non model data (GB)\n")
......@@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module):
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = (
colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
colo_device_memory_capacity(get_accelerator().get_current_device())
- self._memstats_collector._memstats.max_overall_cuda
)
@torch.no_grad()
......
......@@ -3,13 +3,13 @@ from typing import Optional
import torch
import torch.distributed as dist
from colossalai.accelerator import get_accelerator
from colossalai.legacy.registry import OPHOOKS
from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.legacy.zero.shard_utils import BaseShardStrategy
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
......@@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook):
self.process_group = process_group
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
self.computing_device = get_current_device()
self.computing_device = get_accelerator().get_current_device()
self._memstarts_collector = memstarts_collector
self._stateful_tensor_mgr = stateful_tensor_mgr
......
......@@ -11,9 +11,9 @@ MOE_KERNEL = None
def load_moe():
global MOE_KERNEL
from colossalai.kernel.op_builder import MOEBuilder
from colossalai.kernel.kernel_loader import MoeLoader
MOE_KERNEL = MOEBuilder().load()
MOE_KERNEL = MoeLoader().load()
class AllGather(torch.autograd.Function):
......@@ -145,14 +145,8 @@ class AllToAll(torch.autograd.Function):
class HierarchicalAllToAll(torch.autograd.Function):
@staticmethod
def forward(
ctx: Any,
inputs: Tensor,
groups: Tuple[ProcessGroup, ProcessGroup],
src_rank: int
) -> Tensor:
def forward(ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int) -> Tensor:
"""
Returns:
outputs: Tensor
......@@ -276,8 +270,9 @@ class MoeCombine(torch.autograd.Function):
if tokens_grad.dtype != torch.float32:
tokens_grad = tokens_grad.to(torch.float32)
d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
mask, dest_idx)
d_expert, d_logits = MOE_KERNEL.combine_backward(
ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
)
if d_expert.dtype != ctx.dtype:
d_expert = d_expert.to(ctx.dtype)
......
......@@ -8,9 +8,9 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import ProcessGroup
from colossalai.accelerator import get_accelerator
from colossalai.moe._operation import moe_cumsum
from colossalai.moe.manager import MOE_MANAGER
from colossalai.utils import get_current_device
class MoeRouter(nn.Module, ABC):
......@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False):
def __init__(
self,
k_value: int,
capacity_factor_train: float,
capacity_factor_eval: float,
min_capacity: int,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
use_kernel: bool = False,
):
super().__init__()
self.k_value = k_value
self.capacity_factor_train = capacity_factor_train
......@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
if router_probs.dim() == expert_indices.dim() == 2:
router_probs = router_probs.unsqueeze(0)
expert_indices = expert_indices.unsqueeze(0)
assert router_probs.dim() == expert_indices.dim() == 3, \
"router_probs must be 3D tensor and expert_indices must be 4D tensor"
assert (
router_probs.dim() == expert_indices.dim() == 3
), "router_probs must be 3D tensor and expert_indices must be 4D tensor"
# Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
expert_mask = F.one_hot(expert_indices, num_experts)
......@@ -122,25 +125,29 @@ class Top1Router(MoeRouter):
drop_tks (bool, optional): Whether drops tokens in evaluation
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
select_policy: str = "first",
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=1,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
self.select_policy = select_policy
assert select_policy in {"first", "random"}
if select_policy == "random":
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(0.0, device=get_current_device()),
high=torch.tensor(1.0, device=get_current_device())
low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
).rsample
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
......@@ -216,18 +223,22 @@ class Top2Router(MoeRouter):
drop_tks (bool, optional): Whether drops tokens in evaluation.
"""
def __init__(self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks)
def __init__(
self,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
k_value=2,
capacity_factor_train=capacity_factor_train,
capacity_factor_eval=capacity_factor_eval,
min_capacity=min_capacity,
noisy_func=noisy_func,
drop_tks=drop_tks,
)
def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
"""
......@@ -255,8 +266,8 @@ class Top2Router(MoeRouter):
top2_idx = torch.argmax(logits_except1, dim=-1)
mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)
cmask = (mask1 + mask2) # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
cmask = mask1 + mask2 # loss: [s, e]
cmask = cmask.float() / 2.0 # div 2 to normalize it to 1
# calculate loss
expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)
......@@ -269,7 +280,7 @@ class Top2Router(MoeRouter):
dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
capacity = max_num.item()
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel) # rank1: [s, e]
rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
rank2 += torch.sum(mask1, dim=-2, keepdim=True)
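The cumulative sums above assign each routed token a slot index per expert so it can be compared against the capacity; rank2 is offset by the number of top-1 assignments so second choices queue behind first choices. A toy dense-PyTorch version of the top-1 half of that bookkeeping (not the fused moe_cumsum kernel):

import torch
import torch.nn.functional as F

logits = torch.randn(6, 4)                          # [tokens, experts]
top1_idx = torch.argmax(logits, dim=-1)
mask1 = F.one_hot(top1_idx, num_classes=4).to(torch.int32)

# Exclusive cumsum: how many earlier tokens already chose the same expert.
rank1 = torch.cumsum(mask1, dim=0) - mask1
capacity = 3
kept = mask1 * (rank1 < capacity)                   # tokens past an expert's capacity are dropped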
......@@ -336,15 +347,18 @@ class TopKRouter(MoeRouter):
oversubscribed / reach capacity.
"""
def __init__(self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True):
super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
drop_tks)
def __init__(
self,
num_selected_experts: int,
capacity_factor_train: float = 1.25,
capacity_factor_eval: float = 2.0,
min_capacity: int = 4,
noisy_func: Optional[Callable] = None,
drop_tks: bool = True,
):
super().__init__(
num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
)
def forward(
self,
......@@ -410,7 +424,7 @@ class TopKRouter(MoeRouter):
# The combine array will be used for combining expert outputs, scaled by the
# router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
# expert_capacity].
combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
return combine_array, dispatch_mask
......
......@@ -7,13 +7,12 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from colossalai.accelerator import get_accelerator
from colossalai.moe.manager import MOE_MANAGER
from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
from colossalai.utils import get_current_device
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data.clone()
......@@ -30,8 +29,8 @@ class NormalNoiseGenerator:
def __init__(self, num_experts: int):
self.normal = torch.distributions.normal.Normal(
loc=torch.tensor(0.0, device=get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
......@@ -52,8 +51,8 @@ class UniformNoiseGenerator:
def __init__(self, eps: float = 1e-2):
self.uniform = torch.distributions.uniform.Uniform(
low=torch.tensor(1.0 - eps, device=get_current_device()),
high=torch.tensor(1.0 + eps, device=get_current_device()),
low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
).rsample
def __call__(self, inputs: torch.Tensor):
......@@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
epsize_param_dict = dict()
for param in model.parameters():
if not is_moe_tensor(param):
ep_size = 1 # set ep_size to 1 for dp parameters
ep_size = 1 # set ep_size to 1 for dp parameters
else:
ep_size = get_ep_size(param)
if ep_size not in epsize_param_dict:
......@@ -193,18 +192,13 @@ def create_ep_hierarchical_group(
assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
nproc_per_node = int(nproc_per_node)
else:
assert dist.get_world_size() % nproc_per_node == 0, \
"nproc_per_node should be a divisor of world_size."
assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
num_node = dist.get_world_size() // nproc_per_node
intra_src_rank = None
ep_intra_node_group = None
for i in range(num_node):
ep_intra_ranks = [
i * nproc_per_node + j
for j in range(nproc_per_node)
if j in ep_group_ranks
]
ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
group = dist.new_group(ep_intra_ranks)
if rank in ep_intra_ranks:
assert ep_intra_node_group is None
......@@ -212,10 +206,7 @@ def create_ep_hierarchical_group(
intra_src_rank = ep_intra_ranks[0]
ep_inter_node_group = None
ep_inter_ranks = [
ep_group_ranks[0] + i * nproc_per_node
for i in range(num_node)
]
ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
if len(ep_inter_ranks) > 1:
group = dist.new_group(ep_inter_ranks)
if rank in ep_inter_ranks:
......
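create_ep_hierarchical_group carves the expert-parallel ranks into one intra-node group per node plus an inter-node group made of each node's leader rank. The rank bookkeeping alone, with hypothetical numbers:

world_size, nproc_per_node = 16, 8
ep_group_ranks = list(range(world_size))            # hypothetical: EP spans every rank
num_node = world_size // nproc_per_node

intra_groups = [
    [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
    for i in range(num_node)
]
inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
# intra_groups -> [[0..7], [8..15]]; inter_ranks -> [0, 8]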
import enum
import math
from typing import Optional
import warnings
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from einops import rearrange
from ..scaled_softmax import AttnMaskType
from .flash_attn_2 import HAS_FLASH_ATTN
from .mem_eff_attn import HAS_MEM_EFF_ATTN
from .utils import Repad, SeqLenInfo, Unpad
from colossalai.accelerator import get_accelerator
from colossalai.kernel.kernel_loader import FlashAttentionLoader
if HAS_FLASH_ATTN:
from .flash_attn_2 import flash_attention
if HAS_MEM_EFF_ATTN:
from .mem_eff_attn import mem_eff_attention
@dataclass
class SeqLenInfo:
seqlens: Iterable[int] = None
indices: torch.Tensor = None
max_seqlen: int = None
cu_seqlens: torch.Tensor = None
@staticmethod
def materialize(
attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
):
if attn_mask is not None:
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
else:
batch_size, tgt_len = size[0], size[1]
indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
seqlens = torch.full((batch_size,), tgt_len, dtype=torch.long, device=device)
max_seqlen = max(seqlens)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
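materialize converts a padding mask into the flattened-token bookkeeping the varlen attention kernels expect: the indices of the real tokens, per-sequence lengths, and their cumulative offsets. With a toy mask:

import torch
import torch.nn.functional as F

attn_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]])
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()        # [0, 1, 2, 4, 5]
seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()                  # [3, 2]
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))   # [0, 3, 5]
max_seqlen = int(seqlens.max())                                               # 3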
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class Unpad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
ctx.save_for_backward(indices)
# [b, s, ...]
assert tensor.ndim >= 3
ctx.bsz = tensor.shape[0]
out = rearrange(tensor, "b s ... -> (b s) ...")
ctx.shape = out.shape
# [ntokens, ...]
return out[indices]
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [ntokens, ...]
grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
grad[indices] = grad_output
grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
# [b, s, ...]
return grad, None
class Repad(torch.autograd.Function):
"""
Adapted from
https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
"""
@staticmethod
def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
ctx.save_for_backward(indices)
# [ntokens, ...]
tensor = tensor
out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
# [b*s, ...]
out[indices] = tensor
return out
@staticmethod
def backward(ctx, grad_output):
(indices,) = ctx.saved_tensors
# [b*s, ...]
grad = grad_output[indices]
# [ntokens, ...]
return grad, None, None, None
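Unpad and Repad are inverse operations around the varlen kernels: the first drops padded positions before the attention call, the second scatters the outputs back into the padded layout. A plain-tensor sketch of the same round trip (without the autograd machinery above):

import torch

batch_size, seq_len, dim = 2, 4, 8
x = torch.randn(batch_size, seq_len, dim)
attn_mask = torch.tensor([[1, 1, 1, 0],
                          [1, 1, 0, 0]])
indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten()

tokens = x.reshape(batch_size * seq_len, dim)[indices]     # unpad: [ntokens, dim]
restored = torch.zeros(batch_size * seq_len, dim)
restored[indices] = tokens                                 # repad: zeros at padded slots
restored = restored.reshape(batch_size, seq_len, dim)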
class ColoAttention(torch.nn.Module):
......@@ -27,8 +106,7 @@ class ColoAttention(torch.nn.Module):
self.scale = 1 / math.sqrt(embed_dim // num_heads)
self.dropout = dropout
if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
raise Exception("flash attention can not support!")
self.attn = FlashAttentionLoader().load()
@staticmethod
def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
......@@ -44,14 +122,30 @@ class ColoAttention(torch.nn.Module):
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
origin_attn_mask: Optional[torch.Tensor] = None,
attn_mask_type: Optional[AttnMaskType] = None,
bias: Optional[torch.Tensor] = None,
):
attn = None
if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
attn = flash_attention
else:
attn = mem_eff_attention
"""
ColoAttention
Args:
q: (batch, q_seqlen, nheads, headdim)
k: (batch, kv_seqlen, nheads, headdim)
v: (batch, kv_seqlen, nheads, headdim)
origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
bias: will not be used
Return:
attn_out: (batch, q_seqlen, nheads, headdim).
"""
# if flash attention is not applicable, switch to memory efficient attention
if self.attn.__name__ == "flash_attention" and (
query.dtype not in [torch.float16, torch.bfloat16] or bias != None
):
warnings.warn(
f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation."
)
self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")
padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
causal = attn_mask_type is not None and attn_mask_type.value > 1
......@@ -91,12 +185,13 @@ class ColoAttention(torch.nn.Module):
else:
query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)
out = attn(
out = self.attn(
query,
key,
value,
seq_len_info_q,
seq_len_info_kv,
seq_len_info_q=seq_len_info_q,
seq_len_info_kv=seq_len_info_kv,
origin_attn_mask=origin_attn_mask,
dropout_p=self.dropout,
scale=self.scale,
causal=causal,
......@@ -109,5 +204,6 @@ class ColoAttention(torch.nn.Module):
out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)
out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
out = rearrange(out, "b s h d -> b s (h d)")
if len(out.shape) == 4:
out = rearrange(out, "b s h d -> b s (h d)")
return out
......@@ -9,7 +9,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn import init
from torch.nn.parameter import Parameter
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
from colossalai.kernel.kernel_loader import LayerNormLoader
try:
from colossalai._C import layer_norm
......@@ -29,7 +29,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
global layer_norm
if layer_norm is None:
layer_norm = LayerNormBuilder().load()
layer_norm = LayerNormLoader().load()
output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.layernorm_op = layer_norm
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
......
# This code is from NVIDIA Megatron,
# with minor changes.
import enum
import torch
import torch.nn as nn
from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader
class AttnMaskType(enum.Enum):
padding = 1
causal = 2
paddedcausal = 3
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, scale):
global scaled_upper_triang_masked_softmax
if scaled_upper_triang_masked_softmax is None:
scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs the following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs, mask, scale):
scale_t = torch.tensor([scale])
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None, None
class FusedScaleMaskSoftmax(nn.Module):
"""
Fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: Flag to indicate whether the input is in fp16 format.
input_in_bf16: Flag to indicate whether the input is in bf16 format.
attn_mask_type: Attention mask type (pad or causal).
scaled_masked_softmax_fusion: Flag to indicate whether the user wants to use softmax fusion.
mask_func: Mask function to be applied.
softmax_in_fp32: If True, softmax is performed at fp32 precision.
scale: Scaling factor used in input tensor scaling.
"""
def __init__(
self,
input_in_fp16,
input_in_bf16,
attn_mask_type,
scaled_masked_softmax_fusion,
mask_func,
softmax_in_fp32,
scale,
):
super(FusedScaleMaskSoftmax, self).__init__()
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
assert not (
self.input_in_fp16 and self.input_in_bf16
), "both fp16 and bf16 flags cannot be active at the same time."
self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
self.attn_mask_type = attn_mask_type
self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
self.mask_func = mask_func
self.softmax_in_fp32 = softmax_in_fp32
self.scale = scale
assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"
def forward(self, input, mask):
# [b, np, sq, sk]
assert input.dim() == 4
if self.is_kernel_available(mask, *input.size()):
return self.forward_fused_softmax(input, mask)
else:
return self.forward_torch_softmax(input, mask)
def is_kernel_available(self, mask, b, np, sq, sk):
attn_batches = b * np
if (
self.scaled_masked_softmax_fusion # user want to fuse
and self.input_in_float16 # input must be fp16
and mask is not None # mask tensor must not be None
and 16 < sk <= 2048 # sk must be 16 ~ 2048
and sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
if self.attn_mask_type.value > 1:
if attn_batches % batch_per_block == 0:
return True
else:
if sq % batch_per_block == 0:
return True
return False
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
if self.attn_mask_type.value > 1:
assert sq == sk, "causal mask is only for self attention"
# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
return probs.view(b, np, sq, sk)
else:
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs
def get_batch_per_block(self, sq, sk, b, np):
# build and load kernel if not pre-built
global scaled_masked_softmax
if scaled_masked_softmax is None:
scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()
return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
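When the fused kernel's shape constraints are not met, FusedScaleMaskSoftmax falls back to the plain PyTorch path, which is just scale, mask, and softmax (optionally in fp32). A standalone sketch of that fallback (assuming a boolean mask marking positions to hide):

import torch

def torch_scale_mask_softmax(scores, mask=None, scale=None):
    scores = scores.float()                          # softmax in fp32 for stability
    if scale is not None:
        scores = scores * scale
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1)

probs = torch_scale_mask_softmax(torch.randn(2, 4, 8, 8), scale=0.125)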
import math
import platform
from typing import Optional
import torch
from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
from colossalai.kernel.kernel_loader import CPUAdamLoader
from .nvme_optimizer import NVMeOptimizer
......@@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
cpu_adam = CPUAdamLoader().load()
# if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
......
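CPUAdamLoader now selects the appropriate CPU Adam extension (x86 or aarch64) at runtime, replacing the explicit platform check; the extension object it returns is used exactly as before. A hedged usage sketch (assuming colossalai is installed with the CPU Adam extension available):

from colossalai.kernel.kernel_loader import CPUAdamLoader

cpu_adam = CPUAdamLoader().load()
# Arguments as used above: lr, beta1, beta2, eps, weight_decay, adamw_mode
opt = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)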
......@@ -70,9 +70,9 @@ class FusedAdam(torch.optim.Optimizer):
self.adamw_mode = 1 if adamw_mode else 0
self.set_grad_none = set_grad_none
if multi_tensor_applier.available:
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
# Skip buffer
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
......