OpenDAS / ColossalAI · Commits

Commit fae6c92e (unverified), authored Sep 05, 2023 by Hongxin Liu; committed by GitHub, Sep 05, 2023
Merge branch 'main' into feature/shardformer
Parents: bd186784, ac178ca5

Changes: 113 files in total; this page shows 20 changed files with 269 additions and 238 deletions (+269 -238).
colossalai/legacy/trainer/hooks/_base_hook.py (+0 -0)
colossalai/legacy/trainer/hooks/_checkpoint_hook.py (+4 -3)
colossalai/legacy/trainer/hooks/_commons_.py (+0 -0)
colossalai/legacy/trainer/hooks/_log_hook.py (+5 -5)
colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py (+2 -1)
colossalai/legacy/trainer/hooks/_metric_hook.py (+9 -8)
colossalai/nn/layer/parallel_1d/layers.py (+1 -1)
colossalai/nn/layer/parallel_2d/layers.py (+14 -5)
colossalai/nn/layer/parallel_2p5d/layers.py (+19 -7)
colossalai/nn/layer/parallel_3d/layers.py (+1 -1)
colossalai/nn/layer/parallel_sequence/layers.py (+5 -5)
colossalai/nn/layer/vanilla/layers.py (+1 -1)
colossalai/nn/loss/loss_1d.py (+106 -105)
colossalai/nn/loss/loss_2d.py (+5 -4)
colossalai/nn/loss/loss_2p5d.py (+5 -4)
colossalai/nn/loss/loss_3d.py (+6 -5)
colossalai/nn/loss/loss_moe.py (+81 -80)
colossalai/nn/lr_scheduler/cosine.py (+2 -1)
colossalai/nn/lr_scheduler/linear.py (+1 -1)
colossalai/nn/lr_scheduler/multistep.py (+2 -1)
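Every file below follows the same mechanical pattern: registry objects such as HOOKS, LAYERS, LOSSES, and LR_SCHEDULERS now come from colossalai.legacy.registry instead of colossalai.registry (and the trainer hooks move under colossalai/legacy/ with them), while the decorator call sites are untouched. A minimal before/after sketch of the pattern; the example class is hypothetical and not part of this commit:

# Before this commit:
#   from colossalai.registry import LOSSES
# After this commit:
from colossalai.legacy.registry import LOSSES
from torch.nn.modules.loss import _Loss


@LOSSES.register_module        # decorator usage is unchanged; only the import path moved
class ExampleLoss(_Loss):      # hypothetical class, for illustration only
    def forward(self, logits, targets):
        return (logits - targets).float().pow(2).mean()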
colossalai/trainer/hooks/_base_hook.py → colossalai/legacy/trainer/hooks/_base_hook.py
File moved
colossalai/trainer/hooks/_checkpoint_hook.py → colossalai/legacy/trainer/hooks/_checkpoint_hook.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

 import torch

-from colossalai.logging import get_dist_logger
-from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks import BaseHook
+from colossalai.logging import get_dist_logger
 from colossalai.utils.checkpointing import save_checkpoint

 from ._lr_scheduler_hook import LRSchedulerHook
...
colossalai/trainer/hooks/_commons_.py → colossalai/legacy/trainer/hooks/_commons_.py
File moved
colossalai/trainer/hooks/_log_hook.py → colossalai/legacy/trainer/hooks/_log_hook.py

...
@@ -3,17 +3,17 @@
 import os
 import os.path as osp
 from typing import List

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
+from colossalai.legacy.trainer.hooks._metric_hook import ThroughputMetric
 from colossalai.logging import DistributedLogger
-from colossalai.utils import report_memory_usage, is_dp_rank_0, \
-    is_tp_rank_0, is_no_pp_or_last_stage, MultiTimer
+from colossalai.utils import MultiTimer, is_dp_rank_0, is_no_pp_or_last_stage, is_tp_rank_0, report_memory_usage

 from ._base_hook import BaseHook
 from ._commons_ import _format_number
-from colossalai.trainer.hooks._metric_hook import ThroughputMetric


 class LogByEpochHook(BaseHook):
...
colossalai/trainer/hooks/_lr_scheduler_hook.py → colossalai/legacy/trainer/hooks/_lr_scheduler_hook.py

-from colossalai.registry import HOOKS
 from torch import Tensor

+from colossalai.legacy.registry import HOOKS
+
 from ._metric_hook import LearningRateMetric, MetricHook
...
colossalai/trainer/hooks/_metric_hook.py → colossalai/legacy/trainer/hooks/_metric_hook.py

...
@@ -6,10 +6,11 @@ from typing import Callable
 import torch
 import torch.distributed as dist

 from colossalai.communication import all_reduce
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.registry import HOOKS
+from colossalai.legacy.registry import HOOKS
 from colossalai.utils import get_current_device, is_no_pp_or_last_stage

 from ._base_hook import BaseHook
...
@@ -19,8 +20,8 @@ from ._commons_ import _format_number
 class Metric(ABC):
     """A basic class of metric collectors. It collects a specific
     metric during training or evaluation and would always be used with
     :class:`MetricHook` to help it update its states and show the
     metric. So please use corresponding hook class to make the metric
     collector works.

     Args:
...
@@ -220,9 +221,9 @@ class AccuracyMetric(Metric):
 class MetricHook(BaseHook):
     """Specialized hook classes for :class:`Metric`.
     Some help metric collectors initialize, reset and
     update their states. Others are used to display and
     record the metric.

     Args:
...
@@ -355,7 +356,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                 gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
         sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
...
@@ -366,7 +367,7 @@ class ThroughputMetric(Metric):
             self.last_step_num_samples *= gpc.get_world_size(ParallelMode.DATA)
         else:
             self.last_step_used_time = all_reduce(self.last_step_used_time, ParallelMode.DATA) / \
                 gpc.get_world_size(ParallelMode.DATA)
             self.last_step_num_samples = all_reduce(self.last_step_num_samples, ParallelMode.DATA)
         sample_per_sec = _format_number(self.last_step_num_samples / (self.last_step_used_time + 1e-12).item())
...
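For context on the two ThroughputMetric hunks above: the hook averages the measured step time across the data-parallel group and sums the per-rank sample counts before computing samples per second. A standalone sketch of that reduction using plain torch.distributed in place of colossalai.communication.all_reduce; it assumes an already-initialized process group, and the function name is illustrative:

import torch
import torch.distributed as dist


def data_parallel_samples_per_sec(num_samples: torch.Tensor,
                                  used_time: torch.Tensor,
                                  group=None) -> float:
    world_size = dist.get_world_size(group)
    # Sum the per-rank step times, then average them over the group, mirroring
    # all_reduce(self.last_step_used_time, ParallelMode.DATA) / gpc.get_world_size(ParallelMode.DATA).
    dist.all_reduce(used_time, op=dist.ReduceOp.SUM, group=group)
    used_time = used_time / world_size
    # Sum the number of samples processed by all ranks.
    dist.all_reduce(num_samples, op=dist.ReduceOp.SUM, group=group)
    # The 1e-12 term guards against division by zero, as in the hook.
    return (num_samples / (used_time + 1e-12)).item()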
colossalai/nn/layer/parallel_1d/layers.py

...
@@ -15,8 +15,8 @@ from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.kernel import LayerNorm
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import (
     broadcast_state_dict,
     gather_tensor_parallel_state_dict,
...
colossalai/nn/layer/parallel_2d/layers.py

...
@@ -5,21 +5,30 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter

 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import gather_tensor_parallel_state_dict, partition_tensor_parallel_state_dict
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2D, Matmul_ABT_2D, add_bias_2d, all_gather_tensor_2d, classifier_2d, layernorm_2d,
-                         reduce_scatter_tensor_2d, split_batch_2d)
+from ._operation import (
+    Matmul_AB_2D,
+    Matmul_ABT_2D,
+    add_bias_2d,
+    all_gather_tensor_2d,
+    classifier_2d,
+    layernorm_2d,
+    reduce_scatter_tensor_2d,
+    split_batch_2d,
+)
 from ._utils import assert_summa_initialization, get_summa_dim_from_env
...
colossalai/nn/layer/parallel_2p5d/layers.py

...
@@ -5,22 +5,34 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter

 from colossalai.communication import broadcast
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

 from ..base_layer import ParallelLayer
 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (Matmul_AB_2p5D, Matmul_ABT_2p5D, add_bias_2p5d, all_gather_tensor_2p5d, classifier_2p5d,
-                         layernorm_2p5d, reduce_scatter_tensor_2p5d, split_batch_2p5d)
+from ._operation import (
+    Matmul_AB_2p5D,
+    Matmul_ABT_2p5D,
+    add_bias_2p5d,
+    all_gather_tensor_2p5d,
+    classifier_2p5d,
+    layernorm_2p5d,
+    reduce_scatter_tensor_2p5d,
+    split_batch_2p5d,
+)
 from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
...
colossalai/nn/layer/parallel_3d/layers.py

...
@@ -13,9 +13,9 @@ from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP
 from colossalai.context import ParallelMode, seed
 from colossalai.core import global_context as gpc
 from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
-from colossalai.registry import LAYERS
 from colossalai.utils.checkpointing import (
     broadcast_state_dict,
     gather_tensor_parallel_state_dict,
...
colossalai/nn/layer/parallel_sequence/layers.py

...
@@ -2,20 +2,20 @@
 # -*- encoding: utf-8 -*-

 import math

-import colossalai
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn import Parameter

+import colossalai
+from colossalai.context import seed
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
-from colossalai.registry import LAYERS
-from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
 from colossalai.kernel import FusedScaleMaskSoftmax
-from colossalai.context import seed
+from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
+from colossalai.legacy.registry import LAYERS
+from colossalai.nn.layer.parallel_sequence._operation import RingAV, RingQK


 @LAYERS.register_module
...
colossalai/nn/layer/vanilla/layers.py

...
@@ -8,8 +8,8 @@ from torch import nn as nn
 from torch.nn.parameter import Parameter

 from colossalai.context import seed
+from colossalai.legacy.registry import LAYERS
 from colossalai.nn import init as init
-from colossalai.registry import LAYERS
 from colossalai.utils.cuda import get_current_device

 from ..utils import to_2tuple
...
colossalai/nn/loss/loss_1d.py

 import torch
 import torch.distributed as dist
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.registry import LOSSES
 from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.modules.loss import _Loss

+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES


 class _VocabParallelCrossEntropy1D(torch.autograd.Function):

     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, vocab_parallel_logits, targets, process_group):
         if process_group is None:
             process_group = gpc.get_group(ParallelMode.PARALLEL_1D)

         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
         torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=process_group)
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

         # Get the partition's vocab indices
         partition_vocab_size = vocab_parallel_logits.size()[-1]
         rank = dist.get_rank(process_group)
         vocab_start_index = partition_vocab_size * rank
         vocab_end_index = vocab_start_index + partition_vocab_size

         # Create a mask of valid vocab ids (1 means it needs to be masked).
         target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
         masked_target = targets.clone() - vocab_start_index
         masked_target[target_mask] = 0

         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
         # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
         logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
         masked_target_1d = masked_target.view(-1)
         arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
         predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
         predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(targets)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Sum of exponential of logits along vocab dimension across all GPUs.
         exp_logits = torch.exp(vocab_parallel_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
         torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=process_group)

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits

         # Store softmax, target-mask and masked-target for backward pass.
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
         ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)

         return loss

     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
         # Retrieve tensors from the forward path.
         softmax, target_mask, masked_target_1d = ctx.saved_tensors

         # All the inputs have softmax as their gradient.
         grad_input = softmax
         # For simplicity, work with the 2D gradient.
         partition_vocab_size = softmax.size()[-1]
         grad_2d = grad_input.view(-1, partition_vocab_size)

         # Add the gradient from matching classes.
         arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
         grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())

         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))

         return grad_input, None, None


 @LOSSES.register_module
 class VocabParallelCrossEntropyLoss1D(_Loss):
     """Vocab parallel cross entropy loss for 1D parallelism.

     Args:
         reduction (bool, optional): whether to average the loss, defaults to True.
     """

     def __init__(self, reduction=True):
         super().__init__()
         self.reduction_mean = reduction

     def forward(self, logits, targets, process_group=None):
         """Calculate loss between logits and targets.

         Args:
             logits (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             targets (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         """
         loss = _VocabParallelCrossEntropy1D.apply(logits, targets, process_group)
         if self.reduction_mean:
             loss = loss.mean()
         return loss
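A short usage sketch for the loss above: under 1D tensor parallelism the vocabulary dimension is sharded, so each rank passes only its local slice of the logits while targets carry global vocabulary indices; out-of-shard targets are masked inside the Function and recovered through the all-reduce. The shapes and names below are illustrative, not from the commit:

import torch
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D

world_size = 4                 # size of the 1D tensor-parallel group (assumed)
batch, vocab = 8, 32000        # illustrative sizes; vocab must divide evenly
criterion = VocabParallelCrossEntropyLoss1D(reduction=True)

# Each rank holds a contiguous vocab shard of the logits: [batch, vocab // world_size].
local_logits = torch.randn(batch, vocab // world_size, device="cuda", requires_grad=True)
# Targets hold global vocab ids; ids outside the local shard are masked internally.
targets = torch.randint(0, vocab, (batch,), device="cuda")

# process_group=None falls back to gpc.get_group(ParallelMode.PARALLEL_1D).
loss = criterion(local_logits, targets)
loss.backward()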
colossalai/nn/loss/loss_2d.py

 import torch
 import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.functional import cross_entropy
+from torch.nn.modules.loss import _Loss

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
 from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_batch_2d
 from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
-from colossalai.registry import LOSSES
 from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.functional import cross_entropy
-from torch.nn.modules.loss import _Loss


 @LOSSES.register_module
...
colossalai/nn/loss/loss_2p5d.py

 import torch
 import torch.distributed as dist
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.functional import cross_entropy
+from torch.nn.modules.loss import _Loss

 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
 from colossalai.nn.layer.parallel_2p5d import reduce_by_batch_2p5d, split_batch_2p5d
 from colossalai.nn.layer.parallel_2p5d._utils import assert_tesseract_initialization
-from colossalai.registry import LOSSES
 from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.functional import cross_entropy
-from torch.nn.modules.loss import _Loss


 @LOSSES.register_module
...
colossalai/nn/loss/loss_3d.py

 import torch
 import torch.distributed as dist
-from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn.functional import cross_entropy
+from torch.nn.modules.loss import _Loss

+from colossalai.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.core import global_context as gpc
+from colossalai.legacy.registry import LOSSES
 from colossalai.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
 from colossalai.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
-from colossalai.registry import LOSSES
 from colossalai.utils import get_current_device
-from torch.cuda.amp import custom_bwd, custom_fwd
-from torch.nn.functional import cross_entropy
-from torch.nn.modules.loss import _Loss


 @LOSSES.register_module
...
colossalai/nn/loss/loss_moe.py

 import torch.nn as nn
-from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss

 from colossalai.context.moe_context import MOE_CONTEXT
+from colossalai.legacy.registry import LOSSES


 @LOSSES.register_module
 class MoeCrossEntropyLoss(_Loss):
     r"""torch.nn.CrossEntropyLoss added with auxiliary loss.

     Args:
         input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
         target (:class:`torch.tensor`): Ground truth class indices or class probabilities.
         aux_weight (float, optional): Weight of auxiliary loss in total loss. Defaults 0.01.

     The ``args`` and ``kwargs`` should include parameters below:
     ::

         weight (Tensor, optional)
         size_average (bool, optional)
         ignore_index (int, optional)
         reduce (bool, optional)
         reduction (str, optional)
         label_smoothing (float, optional)

     More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
     `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
     """

     def __init__(self, aux_weight: float = 0.01, *args, **kwargs):
         super().__init__()
         self.loss = nn.CrossEntropyLoss(*args, **kwargs)
         self.aux_weight = aux_weight

     def forward(self, *args):
         """
         The ``args`` should at least include parameters below:
         ::

             input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             target (:class:`torch.tensor`): Ground truth class indices or class probabilities.

         More details about ``args``, ``kwargs`` and ``torch.nn.functional.cross_entropy`` could be found in
         `Cross_entropy <https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html#torch.nn.functional.cross_entropy>`_.
         """
         main_loss = self.loss(*args)
         aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss


 @LOSSES.register_module
 class MoeLoss(_Loss):
     """A wrapper class for any loss module to add with auxiliary loss.

     Args:
         aux_weight (float): Weight of auxiliary loss in total loss.
         loss_fn (``Callable``): Loss function.
         args (list): Args in loss function.
         kwargs (dict): Kwargs in loss function
     """

     def __init__(self, aux_weight: float, loss_fn, *args, **kwargs):
         super().__init__()
         self.loss_fn = loss_fn(*args, **kwargs)
         self.aux_weight = aux_weight

     def forward(self, *args, **kwargs):
         """
         The ``args`` and ``kwargs`` should at least include parameters below:
         ::

             input (:class:`torch.tensor`): Predicted unnormalized scores (often referred to as logits).
             target (:class:`torch.tensor`): Ground truth class indices or class probabilities.

         Note:
             The ``args`` and ``kwargs`` may include different parameters varying with different loss function.
         """
         main_loss = self.loss_fn(*args, **kwargs)
         aux_loss = MOE_CONTEXT.get_loss()
         return main_loss + self.aux_weight * aux_loss
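A brief usage sketch for MoeLoss: it constructs the wrapped loss and adds aux_weight times the auxiliary load-balancing loss that the MoE layers accumulate in MOE_CONTEXT during the model's forward pass. Shapes below are illustrative, and a real run needs MoE layers to have populated MOE_CONTEXT first:

import torch
import torch.nn as nn
from colossalai.nn.loss.loss_moe import MoeLoss

# Extra positional/keyword args are forwarded to the wrapped loss constructor.
criterion = MoeLoss(aux_weight=0.01, loss_fn=nn.CrossEntropyLoss)

logits = torch.randn(8, 10)             # model output after the MoE forward pass
targets = torch.randint(0, 10, (8,))
# total = CrossEntropyLoss()(logits, targets) + 0.01 * MOE_CONTEXT.get_loss()
loss = criterion(logits, targets)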
colossalai/nn/lr_scheduler/cosine.py

 from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR

-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
+
 from .delayed import DelayerScheduler, WarmupDelayerScheduler, WarmupScheduler
...
colossalai/nn/lr_scheduler/linear.py

 from torch.optim.lr_scheduler import _LRScheduler

-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS


 @LR_SCHEDULERS.register_module
...
colossalai/nn/lr_scheduler/multistep.py

...
@@ -2,7 +2,8 @@ from typing import List
 from torch.optim.lr_scheduler import MultiStepLR as _MultiStepLR

-from colossalai.registry import LR_SCHEDULERS
+from colossalai.legacy.registry import LR_SCHEDULERS
+
 from .delayed import WarmupScheduler
...