Commit 8edb777c authored by shenggan's avatar shenggan Committed by Frank Lee
Browse files

[NFC] polish colossalai/nn/loss/loss_2p5d.py code style (#1553)

parent bd2d7898
...@@ -30,6 +30,7 @@ class CrossEntropyLoss2p5D(_Loss): ...@@ -30,6 +30,7 @@ class CrossEntropyLoss2p5D(_Loss):
More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
`Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_. `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
""" """
def __init__(self, reduction=True, *args, **kwargs): def __init__(self, reduction=True, *args, **kwargs):
super().__init__() super().__init__()
assert_tesseract_initialization() assert_tesseract_initialization()
...@@ -127,6 +128,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss): ...@@ -127,6 +128,7 @@ class VocabParallelCrossEntropyLoss2p5D(_Loss):
Args: Args:
reduction (bool, optional): whether to average the loss, defaults to True. reduction (bool, optional): whether to average the loss, defaults to True.
""" """
def __init__(self, reduction=True): def __init__(self, reduction=True):
super().__init__() super().__init__()
self.reduction_mean = reduction self.reduction_mean = reduction
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment