Unverified Commit fae6c92e authored by Hongxin Liu, committed by GitHub

Merge branch 'main' into feature/shardformer

parents bd186784 ac178ca5
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

-from colossalai.logging import get_dist_logger
-from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
from colossalai.utils.checkpointing import save_checkpoint

from ._lr_scheduler_hook import LRSchedulerHook
...
@@ -3,17 +3,17 @@
import os
import os.path as osp
from typing import List

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
-    is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage

from ._base_hook import BaseHook
from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric


class LogByEpochHook(BaseHook):
...
-from colossalai.registry import HOOKS
from torch import Tensor
+from colossalai.legacy.registry import HOOKS

from ._metric_hook import LearningRateMetric, MetricHook
...
@@ -6,10 +6,11 @@ from typing import Callable
import torch
import torch.distributed as dist

from colossalai.communication import all_reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
from colossalai.utils import get_current_device, is_no_pp_or_last_stage

from ._base_hook import BaseHook
@@ -19,8 +20,8 @@ from ._commons_ import _format_number

class Metric(ABC):
    """A basic class of metric collectors. It collects a specific
    metric during training or evaluation and would always be used with
    :class:`MetricHook` to help it update its states and show the
    metric. So please use corresponding hook class to make the metric
    collector works.

    Args:
@@ -220,9 +221,9 @@ class AccuracyMetric(Metric):

class MetricHook(BaseHook):
    """Specialized hook classes for :class:`Metric`.
    Some help metric collectors initialize, reset and
    update their states. Others are used to display and
    record the metric.

    Args:
@@ -355,7 +356,7 @@ class ThroughputMetric(Metric):
            self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
        else:
            self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                gpc.get_world_size(ParallelMode.DATA)
            self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
        sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
@@ -366,7 +367,7 @@ class ThroughputMetric(Metric):
            self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
        else:
            self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                gpc.get_world_size(ParallelMode.DATA)
            self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
        sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
...
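The two hunks above aggregate throughput across the data-parallel group by all-reducing the measured step time (then averaging it) and summing the sample count. A minimal standalone sketch of the same pattern, assuming torch.distributed is already initialized (illustrative only, not part of this commit):

import torch
import torch.distributed as dist

def global_samples_per_sec(local_num_samples: int, local_step_time: float) -> float:
    # Average the measured step time over all data-parallel ranks.
    used_time = torch.tensor(local_step_time, dtype=torch.float32)
    dist.all_reduce(used_time)
    used_time /= dist.get_world_size()
    # Sum the samples processed by every rank.
    num_samples = torch.tensor(float(local_num_samples))
    dist.all_reduce(num_samples)
    # Small epsilon avoids division by zero, as in ThroughputMetric.
    return (num_samples / (used_time + 1e-12)).item()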
@@ -15,8 +15,8 @@ from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import LayerNorm
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (
    broadcast_state_dict,
    gather_tensor_parallel_state_dict,
...
@@ -5,21 +5,30 @@ from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter

from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
-                         reduce_scatter_tensor_2d, split_batch_2d)
+from ._operation import (
+    Matmul_AB_2D,
+    Matmul_ABT_2D,
+    add_bias_2d,
+    all_gather_tensor_2d,
+    classifier_2d,
+    layernorm_2d,
+    reduce_scatter_tensor_2d,
+    split_batch_2d,
+)
from ._utils import assert_summa_initialization, get_summa_dim_from_env
...
@@ -5,22 +5,34 @@ from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter

from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
-                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
+from ._operation import (
+    Matmul_AB_2p5D,
+    Matmul_ABT_2p5D,
+    add_bias_2p5d,
+    all_gather_tensor_2p5d,
+    classifier_2p5d,
+    layernorm_2p5d,
+    reduce_scatter_tensor_2p5d,
+    split_batch_2p5d,
+)
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
...
@@ -13,9 +13,9 @@ from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.registry import LAYERS
from colossalai.utils.checkpointing import (
    broadcast_state_dict,
    gather_tensor_parallel_state_dict,
...
@@ -2,20 +2,20 @@
# -*- encoding: utf-8 -*-
import math

-import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

+import colossalai
+from colossalai.context import seed
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
-from colossalai.registry import LAYERS
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.kernel import FusedScaleMaskSoftmax
-from colossalai.context import seed
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.registry import LAYERS
+from colossalai.nn.layer.parallel_sequence._operation import RingAV, RingQK


@LAYERS.register_module
...
@@ -8,8 +8,8 @@ from torch import nn as nn
from torch.nn.parameter import Parameter

from colossalai.context import seed
+from colossalai.legacy.registry import LAYERS
from colossalai.nn import init as init
-from colossalai.registry import LAYERS
from colossalai.utils.cuda import get_current_device

from ..utils import to_2tuple
...
import torch
import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import LOSSES
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.modules.loss import _Loss

+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES


class _VocabParallelCrossEntropy1D(torch.autograd.Function):

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, vocab_parallel_logits, targets, process_group):
        if process_group is None:
            process_group = gpc.get_group(ParallelMode.PARALLEL_1D)

        # Maximum value along vocab dimension across all GPUs.
        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
        torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
        # Subtract the maximum value.
        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

        # Get the partition's vocab indices
        partition_vocab_size = vocab_parallel_logits.size()[-1]
        rank = dist.get_rank(process_group)
        vocab_start_index = partition_vocab_size * rank
        vocab_end_index = vocab_start_index + partition_vocab_size

        # Create a mask of valid vocab ids (1 means it needs to be masked).
        target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
        masked_target = targets.clone() - vocab_start_index
        masked_target[target_mask] = 0

        # Get predicted-logits = logits[target].
        # For Simplicity, we convert logits to a 2-D tensor with size
        # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
        masked_target_1d = masked_target.view(-1)
        arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
        predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
        predicted_logits = predicted_logits_1d.view_as(targets)
        predicted_logits[target_mask] = 0.0
        # All reduce is needed to get the chunks from other GPUs.
        torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

        # Sum of exponential of logits along vocab dimension across all GPUs.
        exp_logits = torch.exp(vocab_parallel_logits)
        sum_exp_logits = exp_logits.sum(dim=-1)
        torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

        # Loss = log(sum(exp(logits))) - predicted-logit.
        loss = torch.log(sum_exp_logits) - predicted_logits
        # Store softmax, target-mask and masked-target for backward pass.
        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
        ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
        return loss

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # Retrieve tensors from the forward path.
        softmax, target_mask, masked_target_1d = ctx.saved_tensors

        # All the inputs have softmax as their gradient.
        grad_input = softmax
        # For simplicity, work with the 2D gradient.
        partition_vocab_size = softmax.size()[-1]
        grad_2d = grad_input.view(-1, partition_vocab_size)

        # Add the gradient from matching classes.
        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
        grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())

        # Finally elementwise multiplication with the output gradients.
        grad_input.mul_(grad_output.unsqueeze(dim=-1))

        return grad_input, None, None


@LOSSES.register_module
class VocabParallelCrossEntropyLoss1D(_Loss):
    """Vocab parallel cross entropy loss for 1D parallelism.

    Args:
        reduction (bool, optional): whether to average the loss, defaults to True.
    """

    def __init__(self, reduction=True):
        super().__init__()
        self.reduction_mean = reduction

    def forward(self, logits, targets, process_group=None):
        """Calculate loss between logits and targets.

        Args:
            logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
            targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
        """
        loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
        if self.reduction_mean:
            loss = loss.mean()
        return loss
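The forward pass above computes the numerically stable cross entropy as loss = log(sum(exp(logits))) - predicted_logit, with all-reduces gathering the max, the predicted logit, and the exp-sum across the vocabulary shards. A single-process sanity check of that formula (illustrative only, not part of this commit), comparing it against torch.nn.functional.cross_entropy on an unpartitioned vocabulary:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)           # [batch, vocab]
targets = torch.randint(0, 10, (4,))  # ground-truth class indices

# Subtract the per-row max first, as the kernel does, for numerical stability.
shifted = logits - logits.max(dim=-1, keepdim=True).values
predicted = shifted[torch.arange(4), targets]
manual_loss = torch.log(torch.exp(shifted).sum(dim=-1)) - predicted

reference = F.cross_entropy(logits, targets, reduction="none")
assert torch.allclose(manual_loss, reference, atol=1e-6)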
import torch
import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.functional import cross_entropy
+from torch.nn.modules.loss import _Loss

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
-from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.functional import cross_entropy
-from torch.nn.modules.loss import _Loss


@LOSSES.register_module
...
import torch
import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.functional import cross_entropy
+from torch.nn.modules.loss import _Loss

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
-from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.functional import cross_entropy
-from torch.nn.modules.loss import _Loss


@LOSSES.register_module
...
import torch
import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.functional import cross_entropy
+from torch.nn.modules.loss import _Loss

-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
+from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.functional import cross_entropy
-from torch.nn.modules.loss import _Loss


@LOSSES.register_module
...
import torch.nn as nn
-from colossalai.registry import LOSSES
from torch.nn.modules.loss import _Loss

from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.registry import LOSSES


@LOSSES.register_module
class MoeCrossEntropyLoss(_Loss):
    r"""torch.nn.CrossEntropyLoss added with auxiliary loss.

    Args:
        input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
        target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
        aux_weight (float, optional): Weight of auxiliary loss in total loss.Defaults 0.01.

    The ``args`` and ``kwargs`` should include parameters below:
    ::

        weight (Tensor, optional)
        size_average (bool, optional)
        ignore_index (int, optional)
        reduce (bool, optional)
        reduction (str, optional)
        label_smoothing (float, optional)

    More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
    `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
    """

    def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
        super().__init__()
        self.loss = nn.CrossEntropyLoss(*args, **kwargs)
        self.aux_weight = aux_weight

    def forward(self, *args):
        """
        The ``args`` should at least include parameters below:
        ::

            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.

        More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
        `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
        """
        main_loss = self.loss(*args)
        aux_loss = MOE_CONTEXT.get_loss()
        return main_loss + self.aux_weight * aux_loss


@LOSSES.register_module
class MoeLoss(_Loss):
    """A wrapper class for any loss module to add with auxiliary loss.

    Args:
        aux_weight (float): Weight of auxiliary loss in total loss.
        loss_fn (``Callable``): Loss function.
        args (list): Args in loss function.
        kwargs (dict): Kwargs in loss function
    """

    def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
        super().__init__()
        self.loss_fn = loss_fn(*args, **kwargs)
        self.aux_weight = aux_weight

    def forward(self, *args, **kwargs):
        """
        The ``args`` and ``kwargs`` should at least include parameters below:
        ::

            input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
            target (:class:`torch.tensor`): Ground truth class indices or class probabilities.

        Note:
            The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
        """
        main_loss = self.loss_fn(*args, **kwargs)
        aux_loss = MOE_CONTEXT.get_loss()
        return main_loss + self.aux_weight * aux_loss
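Both classes above implement the same pattern: the task loss plus aux_weight times an auxiliary load-balancing loss read from MOE_CONTEXT. A plain-PyTorch sketch of that combination (the aux_loss value here is a hypothetical stand-in for MOE_CONTEXT.get_loss(); not part of this commit):

import torch
import torch.nn as nn

aux_weight = 0.01
main_criterion = nn.CrossEntropyLoss()

logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
aux_loss = torch.tensor(0.3)  # stand-in for MOE_CONTEXT.get_loss()

# Total loss = main task loss + weighted auxiliary (load-balancing) loss.
total_loss = main_criterion(logits, labels) + aux_weight * aux_loss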
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR

-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS

from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler
...
from torch.optim.lr_scheduler import _LRScheduler

-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS


@LR_SCHEDULERS.register_module
...
@@ -2,7 +2,8 @@ from typing import List
from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR

-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS

from .delayed import WarmupScheduler
...