OpenDAS / ColossalAI / Commits / 8823cc48

Unverified commit 8823cc48, authored Jan 29, 2024 by Frank Lee, committed by GitHub on Jan 29, 2024.

    Merge pull request #5310 from hpcaitech/feature/npu

    Feature/npu

Parents: bce9499e, 73f4dc57
Changes: 266 files changed in this commit.
This page shows 20 changed files with 536 additions and 131 deletions (+536 / -131).
Files on this page:

    colossalai/legacy/nn/loss/loss_3d.py                                  +3    -3
    colossalai/legacy/trainer/hooks/_metric_hook.py                       +11   -11
    colossalai/legacy/utils/activation_checkpoint.py                      +5    -5
    colossalai/legacy/utils/common.py                                     +2    -2
    colossalai/legacy/utils/memory.py                                     +6    -3
    colossalai/legacy/utils/profiler/legacy/comm_profiler.py              +2    -2
    colossalai/legacy/zero/gemini/stateful_tensor_mgr.py                  +2    -2
    colossalai/legacy/zero/gemini/tensor_placement_policy.py              +3    -3
    colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py    +5    -3
    colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py           +6    -4
    colossalai/legacy/zero/sharded_model/sharded_model_v2.py              +10   -4
    colossalai/legacy/zero/sharded_model/zero_hook.py                     +2    -2
    colossalai/moe/_operation.py                                          +6    -11
    colossalai/moe/routers.py                                             +65   -51
    colossalai/moe/utils.py                                               +9    -18
    colossalai/nn/layer/colo_attention.py                                 +209  -0
    colossalai/nn/layer/layernorm.py                                      +2    -2
    colossalai/nn/layer/scaled_softmax.py                                 +184  -0
    colossalai/nn/optimizer/cpu_adam.py                                   +2    -3
    colossalai/nn/optimizer/fused_adam.py                                 +2    -2
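Note: most of the changes on this page follow a single pattern. Device queries that used to go through colossalai.utils.get_current_device() now go through the accelerator abstraction added for NPU support, and kernel builders from colossalai.kernel.op_builder are replaced by loaders from colossalai.kernel.kernel_loader. A minimal sketch of the device-lookup pattern, assuming ColossalAI is installed with the colossalai.accelerator module introduced by this feature branch:

# Sketch of the device-lookup migration applied throughout this commit.
# Assumes a ColossalAI build that ships the `colossalai.accelerator` module
# from the feature/npu branch.
import torch

# Old pattern (removed in this commit):
#   from colossalai.utils import get_current_device
#   device = get_current_device()

# New pattern (added in this commit):
from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()  # CUDA or NPU, depending on the backend
buffer = torch.zeros(1, device=device)           # same call sites, now backend-agnostic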
colossalai/legacy/nn/loss/loss_3d.py

@@ -4,12 +4,12 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn.functional import cross_entropy
 from torch.nn.modules.loss import _Loss

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_3d import reduce_by_batch_3d, split_tensor_3d
 from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.legacy.registry import LOSSES
-from colossalai.utils import get_current_device

 @LOSSES.register_module

@@ -80,7 +80,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
         target_mask = (targets < vocab_start) | (targets > vocab_end)
         masked_target = targets.clone() - vocab_start
         masked_target[target_mask] = 0
-        arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_accelerator().get_current_device())
         predicted_logits = logits[arange_1d, masked_target]
         predicted_logits = predicted_logits.clone().contiguous().view_as(targets)
         predicted_logits[target_mask] = 0.0

@@ -110,7 +110,7 @@ class _VocabParallelCrossEntropy3D(torch.autograd.Function):
         grad_2d = input_grad.view(-1, partition_vocab_size)
         # Add the gradient from matching classes.
-        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
+        arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_accelerator().get_current_device())
         grad_2d[arange_1d, masked_target] -= 1.0 - target_mask.view(-1).float()
         input_grad.mul_(output_grad.unsqueeze(dim=-1))
colossalai/legacy/trainer/hooks/_metric_hook.py

@@ -7,12 +7,12 @@ from typing import Callable
 import torch
 import torch.distributed as dist

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.communication import all_reduce
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.registry import HOOKS
 from colossalai.legacy.utils import is_no_pp_or_last_stage
-from colossalai.utils import get_current_device

 from ._base_hook import BaseHook
 from ._commons_ import _format_number

@@ -82,8 +82,8 @@ class LossMetric(Metric):
     def __init__(self, epoch_only):
         super().__init__(epoch_only=epoch_only)
-        self.last_step_loss = torch.zeros(1, device=get_current_device())
-        self.accum_loss = torch.zeros(1, device=get_current_device())
+        self.last_step_loss = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accum_loss = torch.zeros(1, device=get_accelerator().get_current_device())
         self.count = 0

     def reset(self) -> None:

@@ -164,10 +164,10 @@ class AccuracyMetric(Metric):
     def __init__(self, epoch_only: bool, accuracy_func: Callable):
         super().__init__(epoch_only=epoch_only)
         self.acc = accuracy_func
-        self.last_step_sum = torch.zeros(1, device=get_current_device())
-        self.last_step_correct = torch.zeros(1, device=get_current_device())
-        self.accumulated_sum = torch.zeros(1, device=get_current_device())
-        self.accumulated_correct = torch.zeros(1, device=get_current_device())
+        self.last_step_sum = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_correct = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_sum = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_correct = torch.zeros(1, device=get_accelerator().get_current_device())

     def reset(self) -> None:
         self.last_step_sum.zero_()

@@ -320,10 +320,10 @@ class ThroughputMetric(Metric):
         super().__init__(epoch_only=epoch_only)
         self.ignored_steps = ignored_steps
         self.cur_steps = 0
-        self.accumulated_num_samples = torch.zeros(1, device=get_current_device())
-        self.accumulated_used_time = torch.zeros(1, device=get_current_device())
-        self.last_step_num_samples = torch.zeros(1, device=get_current_device())
-        self.last_step_used_time = torch.zeros(1, device=get_current_device())
+        self.accumulated_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.accumulated_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_num_samples = torch.zeros(1, device=get_accelerator().get_current_device())
+        self.last_step_used_time = torch.zeros(1, device=get_accelerator().get_current_device())
         self._tflop_per_step = tflop_per_step
         self._use_local = use_local
colossalai/legacy/utils/activation_checkpoint.py

@@ -6,8 +6,8 @@ import weakref
 import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils.device import autocast, get_current_device

 def copy_to_device(obj, device):

@@ -33,7 +33,7 @@ class CheckpointFunction(torch.autograd.Function):
         check_backward_validity(args)
         ctx.run_function = run_function
         ctx.activation_offload = activation_offload
-        ctx.device = get_current_device()
+        ctx.device = get_accelerator().get_current_device()

         # preserve rng states
         ctx.fwd_cpu_rng_state = torch.get_rng_state()

@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
             inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), autocast():
+            with torch.enable_grad(), get_accelerator().autocast()():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():

@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
         # rerun forward, the inner_pack will store all the activations in storage
         if has_autocast_in_fwd:
-            with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
+            with torch.enable_grad(), get_accelerator().autocast()(), torch.autograd.graph.saved_tensors_hooks(
                 inner_pack, inner_unpack
             ):
                 _unused = function(*args)

@@ -245,7 +245,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
     # get device if we need to offload the activation
     if activation_offload:
-        device = get_current_device()
+        device = get_accelerator().get_current_device()

     # run function with pack and unpack as saved_tensors_hooks
     with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
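The checkpoint recompute path above now asks the accelerator for its autocast context instead of importing autocast from colossalai.utils.device. A hedged sketch of a recompute block built on this API; run_function and detached_inputs are hypothetical stand-ins, not names taken from the file:

# Minimal sketch of a recompute step using the accelerator-provided autocast,
# mirroring the change in CheckpointFunction.backward. `run_function` and
# `detached_inputs` are placeholders for the checkpointed callable and its
# detached inputs.
import torch

from colossalai.accelerator import get_accelerator


def recompute(run_function, detached_inputs, had_autocast_in_fwd: bool):
    if had_autocast_in_fwd:
        # get_accelerator().autocast() hands back the backend's autocast factory,
        # hence the second call to actually enter the context (as in the diff above).
        with torch.enable_grad(), get_accelerator().autocast()():
            return run_function(*detached_inputs)
    with torch.enable_grad():
        return run_function(*detached_inputs)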
colossalai/legacy/utils/common.py

@@ -96,9 +96,9 @@ def _calc_l2_norm(grads):
     global fused_optim
     if fused_optim is None:
-        from colossalai.kernel.op_builder import FusedOptimBuilder
+        from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-        fused_optim = FusedOptimBuilder().load()
+        fused_optim = FusedOptimizerLoader().load()

     norm = 0.0
     if len(grads) > 0:
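_calc_l2_norm keeps its lazy-loading structure; only the loader class changes from the old op_builder names to the new kernel_loader ones. A sketch of the lazy kernel-loading idiom used here, assuming the FusedOptimizerLoader import shown in the diff:

# Lazy, cached loading of a fused kernel, following the idiom in _calc_l2_norm.
fused_optim = None


def get_fused_optim_kernel():
    global fused_optim
    if fused_optim is None:
        # Import locally so that merely importing this module does not
        # trigger a kernel build/load.
        from colossalai.kernel.kernel_loader import FusedOptimizerLoader

        fused_optim = FusedOptimizerLoader().load()
    return fused_optim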
colossalai/legacy/utils/memory.py

@@ -6,9 +6,9 @@ import torch
 import torch.distributed as dist
 from packaging import version

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.core import global_context as gpc
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device

 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 _GLOBAL_CPU_MEM_CAPACITY = -1

@@ -112,7 +112,10 @@ def colo_device_memory_capacity(device: torch.device) -> int:
         # In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
         return colo_get_cpu_memory_capacity() / gpc.num_processes_on_current_node
     if device.type == "cuda":
-        return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
+        return (
+            torch.cuda.get_device_properties(get_accelerator().get_current_device()).total_memory
+            * _GLOBAL_CUDA_MEM_FRACTION
+        )

 def colo_device_memory_used(device: torch.device) -> int:

@@ -153,7 +156,7 @@ def colo_set_process_memory_fraction(ratio: float) -> None:
         return
     global _GLOBAL_CUDA_MEM_FRACTION
     _GLOBAL_CUDA_MEM_FRACTION = ratio
-    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_current_device())
+    torch.cuda.set_per_process_memory_fraction(_GLOBAL_CUDA_MEM_FRACTION, get_accelerator().get_current_device())

 def colo_set_cpu_memory_capacity(size: int) -> None:
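With the accelerator abstraction, capping per-process GPU memory still goes through torch.cuda, but the device index comes from the accelerator. A small sketch of how colo_set_process_memory_fraction's body could be exercised; the 0.5 ratio is just an illustrative value:

# Sketch: cap this process to half of the current device's memory, following
# the pattern in colossalai/legacy/utils/memory.py. Requires a CUDA build of
# PyTorch; on other backends this torch.cuda call would not apply.
import torch

from colossalai.accelerator import get_accelerator

ratio = 0.5  # illustrative value
if torch.cuda.is_available():
    torch.cuda.set_per_process_memory_fraction(ratio, get_accelerator().get_current_device())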
colossalai/legacy/utils/profiler/legacy/comm_profiler.py

@@ -8,7 +8,7 @@ import torch.distributed as dist
 from torch.autograd.profiler import profile
 from torch.distributed import ReduceOp

-from colossalai.utils import get_current_device
+from colossalai.accelerator import get_accelerator

 from .prof_utils import BaseProfiler, _format_bandwidth, _format_memory, _format_time

@@ -177,7 +177,7 @@ class CommProfiler(BaseProfiler):
         assert current_comm_event is not None, "dist op has not been found"

-        buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_current_device())
+        buffer = torch.tensor([current_comm_event.self_cuda_time], device=get_accelerator().get_current_device())
         torch_all_reduce(buffer, op=ReduceOp.MIN, group=group)
         current_comm_event.self_cuda_time = buffer.item()
colossalai/legacy/zero/gemini/stateful_tensor_mgr.py

@@ -3,7 +3,7 @@ import types
 from time import time
 from typing import List

-from colossalai.utils.device import get_current_device
+from colossalai.accelerator import get_accelerator

 from .stateful_tensor import StatefulTensor, TensorState
 from .tensor_placement_policy import TensorPlacementPolicy

@@ -69,7 +69,7 @@ class StatefulTensorMgr(object):
         # move COMPUTE tensors to CUDA
         self._cpu_gpu_move_volume += cuda_demand
         for t in move_to_cuda_tensor_list:
-            colo_model_data_tensor_move_inline(t, get_current_device())
+            colo_model_data_tensor_move_inline(t, get_accelerator().get_current_device())

     @property
     def cpu_gpu_move_volume(self):
colossalai/legacy/zero/gemini/tensor_placement_policy.py

@@ -5,8 +5,8 @@ from typing import List, Optional, Type
 import torch

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.utils.memory import colo_device_memory_capacity
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector

 from .stateful_tensor import StatefulTensor

@@ -38,7 +38,7 @@ class CPUTensorPlacementPolicy(TensorPlacementPolicy):
 class CUDATensorPlacementPolicy(TensorPlacementPolicy):
     def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
         assert torch.cuda.is_available(), "Cannot use CUDATensorPlacementPolicy when CUDA is not available"
-        super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
+        super().__init__(get_accelerator().get_current_device(), mem_stats_collector=mem_stats_collector)

     def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
         return 0, 0

@@ -78,7 +78,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             int: the volume of memory that is evicted
         """
         start = time()
-        cuda_capacity = colo_device_memory_capacity(get_current_device())
+        cuda_capacity = colo_device_memory_capacity(get_accelerator().get_current_device())
         used_cuda_model_data = StatefulTensor.GST_MGR.total_mem["cuda"]
         if warmup:
             # We designate a part of CUDA memory for model data in warmup iterations.
colossalai/legacy/zero/shard_utils/bucket_tensor_shard_strategy.py

@@ -4,8 +4,8 @@ import torch
 import torch.distributed as dist
 from torch._utils import _flatten_dense_tensors as flatten

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils import get_current_device

 from .tensor_shard_strategy import TensorShardStrategy

@@ -30,9 +30,11 @@ class BucketTensorShardStrategy(TensorShardStrategy):
         rank = dist.get_rank(process_group)
         for i in range(world_size):
             if i == rank:
-                buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
+                buffer_list.append(
+                    flatten([t.payload for t in tensor_list]).cuda(get_accelerator().get_current_device())
+                )
             else:
-                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
+                buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_accelerator().get_current_device()))
         dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
         # Move to target device before splitting buffer
         # Ensure we utilize maximum PCIE bandwidth
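The bucketed gather above flattens local shards, all-gathers one flat buffer per rank, and only then splits the result, which keeps PCIe/NVLink transfers large. A self-contained sketch of that idiom with plain tensors; the group setup and shapes are illustrative, not taken from the file:

# Sketch of flatten -> all_gather -> split, the idiom behind
# BucketTensorShardStrategy.gather. Run under torch.distributed with a
# process group already initialized; assumes all ranks hold equally sized shards.
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten

from colossalai.accelerator import get_accelerator


def bucketed_all_gather(local_tensors, process_group=None):
    device = get_accelerator().get_current_device()
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)

    flat = flatten(list(local_tensors)).to(device)
    buffer_list = [flat if i == rank else torch.zeros_like(flat) for i in range(world_size)]
    # One large transfer per rank instead of many small ones.
    dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
    return buffer_list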
colossalai/legacy/zero/shard_utils/tensor_shard_strategy.py

@@ -3,11 +3,11 @@ from typing import List, Optional
 import torch
 import torch.distributed as dist

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_tensor_move_inline
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.legacy.zero.shard_utils.commons import get_shard
 from colossalai.legacy.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.utils import get_current_device

 class TensorShardStrategy(BaseShardStrategy):

@@ -34,9 +34,9 @@ class TensorShardStrategy(BaseShardStrategy):
         if t.is_sharded:
             return
         if t.payload.device.type == "cuda":
-            assert t.payload.device == get_current_device(), (
+            assert t.payload.device == get_accelerator().get_current_device(), (
                 f"shard tensor on cuda device index {t.payload.device.index},"
-                f" but current cuda device is {get_current_device()}"
+                f" but current cuda device is {get_accelerator().get_current_device()}"
             )
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.payload_reset(sharded_payload)

@@ -50,7 +50,9 @@ class TensorShardStrategy(BaseShardStrategy):
         world_size = dist.get_world_size(process_group)
         rank = dist.get_rank(process_group)

-        buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
+        buffer = torch.empty(
+            payload_numel * world_size, dtype=t.payload.dtype, device=get_accelerator().get_current_device()
+        )
         buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
         buffer_list[rank].copy_(t.payload)
colossalai/legacy/zero/sharded_model/sharded_model_v2.py

@@ -10,6 +10,7 @@ import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.utils.memory import colo_device_memory_capacity

@@ -22,7 +23,7 @@ from colossalai.legacy.zero.gemini.tensor_utils import colo_model_data_move_to_c
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.legacy.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.logging import get_dist_logger
-from colossalai.utils import disposable, get_current_device
+from colossalai.utils import disposable
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector

 from ._utils import (

@@ -212,8 +213,12 @@ class ShardedModelV2(nn.Module):
             self.logger.error(f"dump memory tracer collected information to a {filename}", ranks=[0])
             if gpc.get_global_rank() == 0:
                 with open(filename, "w+") as f:
-                    f.write(f"cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n")
-                    f.write(f"cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n")
+                    f.write(
+                        f"cuda reserved {torch.cuda.memory_reserved(get_accelerator().get_current_device()) / 1e9} GB\n"
+                    )
+                    f.write(
+                        f"cuda max allocated {torch.cuda.max_memory_allocated(get_accelerator().get_current_device()) / 1e9} GB\n"
+                    )
                     f.write("CUDA model data (GB)\n")
                     f.write("\n")
                     f.write("CUDA non model data (GB)\n")

@@ -266,7 +271,8 @@ class ShardedModelV2(nn.Module):
             # model data is fixed in cuda during training.
             # cuda margin space can be used to store OS.
             self._cuda_margin_space = (
-                colo_device_memory_capacity(get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
+                colo_device_memory_capacity(get_accelerator().get_current_device())
+                - self._memstats_collector._memstats.max_overall_cuda
             )

     @torch.no_grad()
colossalai/legacy/zero/sharded_model/zero_hook.py

@@ -3,13 +3,13 @@ from typing import Optional
 import torch
 import torch.distributed as dist

+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.registry import OPHOOKS
 from colossalai.legacy.zero.gemini.ophooks import BaseOpHook
 from colossalai.legacy.zero.gemini.stateful_tensor import TensorState
 from colossalai.legacy.zero.gemini.stateful_tensor_mgr import StatefulTensorMgr
 from colossalai.legacy.zero.shard_utils import BaseShardStrategy
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 from colossalai.zero.gemini.memory_tracer import MemStatsCollector

@@ -33,7 +33,7 @@ class ZeroHook(BaseOpHook):
         self.process_group = process_group

         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
-        self.computing_device = get_current_device()
+        self.computing_device = get_accelerator().get_current_device()

         self._memstarts_collector = memstarts_collector
         self._stateful_tensor_mgr = stateful_tensor_mgr
colossalai/moe/_operation.py

@@ -11,9 +11,9 @@ MOE_KERNEL = None
 def load_moe():
     global MOE_KERNEL
-    from colossalai.kernel.op_builder import MOEBuilder
+    from colossalai.kernel.kernel_loader import MoeLoader

-    MOE_KERNEL = MOEBuilder().load()
+    MOE_KERNEL = MoeLoader().load()

 class AllGather(torch.autograd.Function):

@@ -145,14 +145,8 @@ class AllToAll(torch.autograd.Function):
 class HierarchicalAllToAll(torch.autograd.Function):
     @staticmethod
     def forward(
         ctx: Any, inputs: Tensor, groups: Tuple[ProcessGroup, ProcessGroup], src_rank: int
     ) -> Tensor:
         """
         Returns:
             outputs: Tensor

@@ -276,8 +270,9 @@ class MoeCombine(torch.autograd.Function):
         if tokens_grad.dtype != torch.float32:
             tokens_grad = tokens_grad.to(torch.float32)

-        d_expert, d_logits = MOE_KERNEL.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits,
-                                                         mask, dest_idx)
+        d_expert, d_logits = MOE_KERNEL.combine_backward(
+            ctx.s, ctx.e, ctx.c, ctx.h, tokens_grad, expert_tokens, logits, mask, dest_idx
+        )
         if d_expert.dtype != ctx.dtype:
             d_expert = d_expert.to(ctx.dtype)
colossalai/moe/routers.py

@@ -8,9 +8,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.distributed import ProcessGroup

+from colossalai.accelerator import get_accelerator
 from colossalai.moe._operation import moe_cumsum
 from colossalai.moe.manager import MOE_MANAGER
-from colossalai.utils import get_current_device

 class MoeRouter(nn.Module, ABC):

@@ -24,14 +24,16 @@ class MoeRouter(nn.Module, ABC):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """

-    def __init__(self,
-                 k_value: int,
-                 capacity_factor_train: float,
-                 capacity_factor_eval: float,
-                 min_capacity: int,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True,
-                 use_kernel: bool = False):
+    def __init__(
+        self,
+        k_value: int,
+        capacity_factor_train: float,
+        capacity_factor_eval: float,
+        min_capacity: int,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+        use_kernel: bool = False,
+    ):
         super().__init__()
         self.k_value = k_value
         self.capacity_factor_train = capacity_factor_train

@@ -68,8 +70,9 @@ class MoeRouter(nn.Module, ABC):
         if router_probs.dim() == expert_indices.dim() == 2:
             router_probs = router_probs.unsqueeze(0)
             expert_indices = expert_indices.unsqueeze(0)
-        assert router_probs.dim() == expert_indices.dim() == 3, \
-            "router_probs must be 3D tensor and expert_indices must be 4D tensor"
+        assert (
+            router_probs.dim() == expert_indices.dim() == 3
+        ), "router_probs must be 3D tensor and expert_indices must be 4D tensor"

         # Shape: [num_groups, tokens_per_group, num_selected_experts, num_experts].
         expert_mask = F.one_hot(expert_indices, num_experts)

@@ -122,25 +125,29 @@ class Top1Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation
     """

-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 select_policy: str = "first",
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=1,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        select_policy: str = "first",
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=1,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )
         self.select_policy = select_policy
         assert select_policy in {"first", "random"}
         if select_policy == "random":
             self.uniform = torch.distributions.uniform.Uniform(
-                low=torch.tensor(0.0, device=get_current_device()),
-                high=torch.tensor(1.0, device=get_current_device())
+                low=torch.tensor(0.0, device=get_accelerator().get_current_device()),
+                high=torch.tensor(1.0, device=get_accelerator().get_current_device()),
             ).rsample

     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:

@@ -216,18 +223,22 @@ class Top2Router(MoeRouter):
         drop_tks (bool, optional): Whether drops tokens in evaluation.
     """

-    def __init__(self,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(k_value=2,
-                         capacity_factor_train=capacity_factor_train,
-                         capacity_factor_eval=capacity_factor_eval,
-                         min_capacity=min_capacity,
-                         noisy_func=noisy_func,
-                         drop_tks=drop_tks)
+    def __init__(
+        self,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            k_value=2,
+            capacity_factor_train=capacity_factor_train,
+            capacity_factor_eval=capacity_factor_eval,
+            min_capacity=min_capacity,
+            noisy_func=noisy_func,
+            drop_tks=drop_tks,
+        )

     def forward(self, inputs: torch.Tensor, use_kernel: bool = False, ep_group: Optional[ProcessGroup] = None) -> Tuple:
         """

@@ -255,8 +266,8 @@ class Top2Router(MoeRouter):
         top2_idx = torch.argmax(logits_except1, dim=-1)
         mask2 = F.one_hot(top2_idx, num_classes=num_experts).to(torch.int32)

-        cmask = (mask1 + mask2)  # loss: [s, e]
+        cmask = mask1 + mask2  # loss: [s, e]
         cmask = cmask.float() / 2.0  # div 2 to normalize it to 1

         # calculate loss
         expert_indices = torch.stack([top1_idx, top2_idx], dim=-1)

@@ -269,7 +280,7 @@ class Top2Router(MoeRouter):
             dist.all_reduce(max_num, op=dist.ReduceOp.MAX, group=ep_group)
             capacity = max_num.item()

         rank1 = moe_cumsum(mask1, use_kernel=self.use_kernel)  # rank1: [s, e]
         rank2 = moe_cumsum(mask2, use_kernel=self.use_kernel)
         rank2 += torch.sum(mask1, dim=-2, keepdim=True)

@@ -336,15 +347,18 @@ class TopKRouter(MoeRouter):
         oversubscribed / reach capacity.
     """

-    def __init__(self,
-                 num_selected_experts: int,
-                 capacity_factor_train: float = 1.25,
-                 capacity_factor_eval: float = 2.0,
-                 min_capacity: int = 4,
-                 noisy_func: Optional[Callable] = None,
-                 drop_tks: bool = True):
-        super().__init__(num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func,
-                         drop_tks)
+    def __init__(
+        self,
+        num_selected_experts: int,
+        capacity_factor_train: float = 1.25,
+        capacity_factor_eval: float = 2.0,
+        min_capacity: int = 4,
+        noisy_func: Optional[Callable] = None,
+        drop_tks: bool = True,
+    ):
+        super().__init__(
+            num_selected_experts, capacity_factor_train, capacity_factor_eval, min_capacity, noisy_func, drop_tks
+        )

     def forward(
         self,

@@ -410,7 +424,7 @@ class TopKRouter(MoeRouter):
         # The combine array will be used for combining expert outputs, scaled by the
         # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
         # expert_capacity].
-        combine_array = torch.einsum('...te,...tec->...tec', router_probs, dispatch_mask)
+        combine_array = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)

         return combine_array, dispatch_mask
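Top1Router's "random" select policy builds its tie-breaking sampler on the accelerator's current device, as shown above. A standalone sketch of that device-aware sampler; the routing-score tensor and the way it is perturbed here are illustrative, not lifted from the router:

# Sketch: the device-aware uniform sampler that Top1Router constructs for its
# "random" select policy, built on the accelerator's current device.
import torch

from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()
uniform = torch.distributions.uniform.Uniform(
    low=torch.tensor(0.0, device=device),
    high=torch.tensor(1.0, device=device),
).rsample

scores = torch.rand(8, 4, device=device)    # illustrative routing scores [s, e]
perturbed = scores * uniform(scores.shape)  # random tie-breaking among candidates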
colossalai/moe/utils.py

@@ -7,13 +7,12 @@ import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

+from colossalai.accelerator import get_accelerator
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.tensor.moe_tensor.api import get_dp_group, get_dp_group_ranks, get_ep_size, is_moe_tensor
-from colossalai.utils import get_current_device

 class ForceFP32Parameter(torch.nn.Parameter):
     def half(self, memory_format=None):
         return self.data.clone()

@@ -30,8 +29,8 @@ class NormalNoiseGenerator:
     def __init__(self, num_experts: int):
         self.normal = torch.distributions.normal.Normal(
-            loc=torch.tensor(0.0, device=get_current_device()),
-            scale=torch.tensor(1.0 / num_experts**2, device=get_current_device()),
+            loc=torch.tensor(0.0, device=get_accelerator().get_current_device()),
+            scale=torch.tensor(1.0 / num_experts**2, device=get_accelerator().get_current_device()),
         ).rsample

     def __call__(self, inputs: torch.Tensor):

@@ -52,8 +51,8 @@ class UniformNoiseGenerator:
     def __init__(self, eps: float = 1e-2):
         self.uniform = torch.distributions.uniform.Uniform(
-            low=torch.tensor(1.0 - eps, device=get_current_device()),
-            high=torch.tensor(1.0 + eps, device=get_current_device()),
+            low=torch.tensor(1.0 - eps, device=get_accelerator().get_current_device()),
+            high=torch.tensor(1.0 + eps, device=get_accelerator().get_current_device()),
        ).rsample

     def __call__(self, inputs: torch.Tensor):

@@ -142,7 +141,7 @@ def get_moe_epsize_param_dict(model: nn.Module) -> Dict[int, List[nn.Parameter]]
     epsize_param_dict = dict()
     for param in model.parameters():
         if not is_moe_tensor(param):
-            ep_size = 1    # set ep_size to 1 for dp parameters
+            ep_size = 1  # set ep_size to 1 for dp parameters
         else:
             ep_size = get_ep_size(param)
         if ep_size not in epsize_param_dict:

@@ -193,18 +192,13 @@ def create_ep_hierarchical_group(
         assert nproc_per_node is not None, "Please use torchrun to launch the job, or specify nproc_per_node manually."
         nproc_per_node = int(nproc_per_node)
     else:
-        assert dist.get_world_size() % nproc_per_node == 0, \
-            "nproc_per_node should be a divisor of world_size."
+        assert dist.get_world_size() % nproc_per_node == 0, "nproc_per_node should be a divisor of world_size."
     num_node = dist.get_world_size() // nproc_per_node

     intra_src_rank = None
     ep_intra_node_group = None
     for i in range(num_node):
-        ep_intra_ranks = [
-            i * nproc_per_node + j
-            for j in range(nproc_per_node)
-            if j in ep_group_ranks
-        ]
+        ep_intra_ranks = [i * nproc_per_node + j for j in range(nproc_per_node) if j in ep_group_ranks]
         group = dist.new_group(ep_intra_ranks)
         if rank in ep_intra_ranks:
             assert ep_intra_node_group is None

@@ -212,10 +206,7 @@ def create_ep_hierarchical_group(
             intra_src_rank = ep_intra_ranks[0]

     ep_inter_node_group = None
-    ep_inter_ranks = [
-        ep_group_ranks[0] + i * nproc_per_node
-        for i in range(num_node)
-    ]
+    ep_inter_ranks = [ep_group_ranks[0] + i * nproc_per_node for i in range(num_node)]
     if len(ep_inter_ranks) > 1:
         group = dist.new_group(ep_inter_ranks)
         if rank in ep_inter_ranks:
colossalai/kernel/cuda_native/mha/mha.py → colossalai/nn/layer/colo_attention.py

+import enum
 import math
-from typing import Optional
+import warnings
+from dataclasses import dataclass
+from typing import Iterable, Optional, Tuple

 import torch
+import torch.nn.functional as F
 from einops import rearrange

-from ..scaled_softmax import AttnMaskType
-from .flash_attn_2 import HAS_FLASH_ATTN
-from .mem_eff_attn import HAS_MEM_EFF_ATTN
-from .utils import Repad, SeqLenInfo, Unpad
+from colossalai.accelerator import get_accelerator
+from colossalai.kernel.kernel_loader import FlashAttentionLoader

-if HAS_FLASH_ATTN:
-    from .flash_attn_2 import flash_attention
-if HAS_MEM_EFF_ATTN:
-    from .mem_eff_attn import mem_eff_attention
+
+@dataclass
+class SeqLenInfo:
+    seqlens: Iterable[int] = None
+    indices: torch.Tensor = None
+    max_seqlen: int = None
+    cu_seqlens: torch.Tensor = None
+
+    @staticmethod
+    def materialize(
+        attn_mask: torch.Tensor = None, size: Tuple[int] = None, device=get_accelerator().get_current_device()
+    ):
+        if attn_mask is not None:
+            indices = torch.nonzero(attn_mask.flatten(), as_tuple=False).flatten().to(device)
+            seqlens = attn_mask.sum(dim=-1, dtype=torch.int32).flatten()
+        else:
+            batch_size, tgt_len = size[0], size[1]
+            indices = torch.arange(batch_size * tgt_len, dtype=torch.long, device=device)
+            seqlens = torch.LongTensor([tgt_len] * batch_size, device=device)
+        max_seqlen = max(seqlens)
+        cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)).to(device)
+        return SeqLenInfo(seqlens.tolist(), indices, max_seqlen, cu_seqlens)
+
+
+class AttnMaskType(enum.Enum):
+    padding = 1
+    causal = 2
+    paddedcausal = 3
+
+
+class Unpad(torch.autograd.Function):
+    """
+    Adapted from
+    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+    """
+
+    @staticmethod
+    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor):
+        ctx.save_for_backward(indices)
+        # [b, s, ...]
+        assert tensor.ndim >= 3
+        ctx.bsz = tensor.shape[0]
+        out = rearrange(tensor, "b s ... -> (b s) ...")
+        ctx.shape = out.shape
+        # [ntokens, ...]
+        return out[indices]
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # [ntokens, ...]
+        grad = torch.zeros(ctx.shape, dtype=grad_output.dtype, device=grad_output.device)
+        grad[indices] = grad_output
+        grad = rearrange(grad, "(b s) ... -> b s ...", b=ctx.bsz)
+        # [b, s, ...]
+        return grad, None
+
+
+class Repad(torch.autograd.Function):
+    """
+    Adapted from
+    https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
+    """
+
+    @staticmethod
+    def forward(ctx, tensor: torch.Tensor, indices: torch.Tensor, batch_size: int, seq_len: int):
+        ctx.save_for_backward(indices)
+        # [ntokens, ...]
+        tensor = tensor
+        out = torch.zeros((batch_size * seq_len, *tensor.shape[1:]), dtype=tensor.dtype, device=tensor.device)
+        # [b*s, ...]
+        out[indices] = tensor
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        (indices,) = ctx.saved_tensors
+        # [b*s, ...]
+        grad = grad_output[indices]
+        # [ntokens, ...]
+        return grad, None, None, None


 class ColoAttention(torch.nn.Module):

@@ -27,8 +106,7 @@ class ColoAttention(torch.nn.Module):
         self.scale = 1 / math.sqrt(embed_dim // num_heads)
         self.dropout = dropout

-        if not HAS_MEM_EFF_ATTN and not HAS_FLASH_ATTN:
-            raise Exception("flash attention can not support!")
+        self.attn = FlashAttentionLoader().load()

     @staticmethod
     def unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:

@@ -44,14 +122,30 @@ class ColoAttention(torch.nn.Module):
         key: torch.Tensor,
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
+        origin_attn_mask: Optional[torch.Tensor] = None,
         attn_mask_type: Optional[AttnMaskType] = None,
         bias: Optional[torch.Tensor] = None,
     ):
-        attn = None
-        if HAS_FLASH_ATTN and query.dtype in [torch.float16, torch.bfloat16] and bias == None:
-            attn = flash_attention
-        else:
-            attn = mem_eff_attention
+        """
+        ColoAttention
+
+        Args:
+            q: (batch, q_seqlen, nheads, headdim)
+            k: (batch, kv_seqlen, nheads, headdim)
+            v: (batch, kv_seqlen, nheads, headdim)
+            origin_attn_mask: (nheads, q_seqlen, kv_seqlen)
+            bias: will not be used
+        Return:
+            attn_out: (batch, q_seqlen, nheads, headdim).
+        """
+        # if flash attention is not applicable, switch to memory effcient attention
+        if self.attn.__name__ == "flash_attention" and (
+            query.dtype not in [torch.float16, torch.bfloat16] or bias != None
+        ):
+            warnings.warn(f"flash-attn expects fp16 or bf16 but got {query.dtype}, switching to xformers' implementation.")
+            self.attn = FlashAttentionLoader().load(ext_name="flash_attention_xformers_cuda")

         padded = attn_mask_type is not None and attn_mask_type.value % 2 == 1
         causal = attn_mask_type is not None and attn_mask_type.value > 1

@@ -91,12 +185,13 @@ class ColoAttention(torch.nn.Module):
         else:
             query, key, value = torch.stack([query, key, value], dim=2).squeeze(0).unbind(dim=1)

-        out = attn(
+        out = self.attn(
             query,
             key,
             value,
-            seq_len_info_q,
-            seq_len_info_kv,
+            seq_len_info_q=seq_len_info_q,
+            seq_len_info_kv=seq_len_info_kv,
+            origin_attn_mask=origin_attn_mask,
             dropout_p=self.dropout,
             scale=self.scale,
             causal=causal,

@@ -109,5 +204,6 @@ class ColoAttention(torch.nn.Module):
             out = self.repad(out, seq_len_info_q.indices, batch_size, tgt_len)

         out = rearrange(out, "(b s) h d -> b s h d", b=batch_size)
-        out = rearrange(out, "b s h d -> b s (h d)")
+        if len(out.shape) == 4:
+            out = rearrange(out, "b s h d -> b s (h d)")
         return out
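A hedged usage sketch of the relocated ColoAttention module. The constructor arguments (embed_dim, num_heads, dropout) are inferred from the attributes visible in the diff and are an assumption about the full signature; the q/k/v shapes follow the forward() docstring. It needs a CUDA or NPU build with a flash-attention kernel available to FlashAttentionLoader:

# Hedged usage sketch for colossalai.nn.layer.colo_attention.ColoAttention.
# Constructor arguments are inferred, not confirmed by this page.
import torch

from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

batch, seqlen, nheads, headdim = 2, 128, 8, 64
attn = ColoAttention(embed_dim=nheads * headdim, num_heads=nheads, dropout=0.0)

q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn(q, k, v, attn_mask_type=AttnMaskType.causal)  # (batch, seqlen, nheads * headdim)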
colossalai/kernel/cuda_native/layer_norm.py → colossalai/nn/layer/layernorm.py

@@ -9,7 +9,7 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 from torch.nn import init
 from torch.nn.parameter import Parameter

-from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
+from colossalai.kernel.kernel_loader import LayerNormLoader

 try:
     from colossalai._C import layer_norm

@@ -29,7 +29,7 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         global layer_norm
         if layer_norm is None:
-            layer_norm = LayerNormBuilder().load()
+            layer_norm = LayerNormLoader().load()
         output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
         ctx.layernorm_op = layer_norm
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
colossalai/nn/layer/scaled_softmax.py (new file)

# This code from NVIDIA Megatron:
#     with minor changes.

import enum

import torch
import torch.nn as nn

from colossalai.kernel.kernel_loader import ScaledMaskedSoftmaxLoader, ScaledUpperTriangleMaskedSoftmaxLoader


class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2
    paddedcausal = 3


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply upper triangular mask (typically used in gpt models).
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, scale):
        global scaled_upper_triang_masked_softmax
        if scaled_upper_triang_masked_softmax:
            scaled_upper_triang_masked_softmax = ScaledUpperTriangleMaskedSoftmaxLoader().load()

        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])

        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])

        return input_grads, None


class ScaledMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply the mask.
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs, mask, scale):
        scale_t = torch.tensor([scale])

        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors

        input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None, None


class FusedScaleMaskSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax

    Arguments:
        input_in_fp16: Flag to indicate if input in fp16 data format.
        input_in_bf16: Flag to indicate if input in bf16 data format.
        attn_mask_type: Attention mask type (pad or causal)
        scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion
        mask_func: Mask function to be applied.
        softmax_in_fp32: If True, softmax in performed at fp32 precision.
        scale: Scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        input_in_fp16,
        input_in_bf16,
        attn_mask_type,
        scaled_masked_softmax_fusion,
        mask_func,
        softmax_in_fp32,
        scale,
    ):
        super(FusedScaleMaskSoftmax, self).__init__()
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        assert not (
            self.input_in_fp16 and self.input_in_bf16
        ), "both fp16 and bf16 flags cannot be active at the same time."
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled"

    def forward(self, input, mask):
        # [b, np, sq, sk]
        assert input.dim() == 4

        if self.is_kernel_available(mask, *input.size()):
            return self.forward_fused_softmax(input, mask)
        else:
            return self.forward_torch_softmax(input, mask)

    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        if (
            self.scaled_masked_softmax_fusion  # user want to fuse
            and self.input_in_float16  # input must be fp16
            and mask is not None  # mask tensor must not be None
            and 16 < sk <= 2048  # sk must be 16 ~ 2048
            and sq % 4 == 0  # sq must be divisor of 4
            and attn_batches % 4 == 0  # np * b must be divisor of 4
        ):
            if 0 <= sk <= 2048:
                batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                if self.attn_mask_type.value > 1:
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False

    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type.value > 1:
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            input = input.view(-1, sq, sk)
            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
            return probs.view(b, np, sq, sk)
        else:
            # input is 4D tensor (b, np, sq, sk)
            return ScaledMaskedSoftmax.apply(input, mask, scale)

    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()

        if self.scale is not None:
            input = input * self.scale
        mask_output = self.mask_func(input, mask) if mask is not None else input
        probs = torch.nn.Softmax(dim=-1)(mask_output)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = probs.half()
            else:
                probs = probs.bfloat16()

        return probs

    def get_batch_per_block(self, sq, sk, b, np):
        # build and load kernel if not pre-built
        global scaled_masked_softmax
        if scaled_masked_softmax is None:
            scaled_masked_softmax = ScaledMaskedSoftmaxLoader().load()

        return scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
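The new scaled_softmax.py keeps Megatron's dispatch between the fused kernels and a plain PyTorch fallback. A sketch that exercises the pure-PyTorch path with fusion disabled; the mask_func supplied here is a common additive-masking choice and is an assumption, since the module leaves mask_func to the caller:

# Sketch: FusedScaleMaskSoftmax on its pure-PyTorch fallback path
# (scaled_masked_softmax_fusion=False), so no fused kernel is required.
import torch

from colossalai.nn.layer.scaled_softmax import AttnMaskType, FusedScaleMaskSoftmax

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=False,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=False,  # force forward_torch_softmax
    mask_func=lambda scores, mask: scores.masked_fill(mask, -10000.0),  # assumed mask function
    softmax_in_fp32=True,
    scale=None,
)

b, heads, sq, sk = 2, 4, 16, 16
scores = torch.randn(b, heads, sq, sk)
mask = torch.zeros(b, 1, sq, sk, dtype=torch.bool)  # nothing masked
probs = softmax(scores, mask)  # [b, heads, sq, sk], rows sum to 1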
colossalai/nn/optimizer/cpu_adam.py

 import math
-import platform
 from typing import Optional

 import torch

-from colossalai.kernel.op_builder import ArmCPUAdamBuilder, CPUAdamBuilder
+from colossalai.kernel.kernel_loader import CPUAdamLoader

 from .nvme_optimizer import NVMeOptimizer

@@ -78,7 +77,7 @@ class CPUAdam(NVMeOptimizer):
         default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
         super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
         self.adamw_mode = adamw_mode
-        cpu_adam = ArmCPUAdamBuilder().load() if platform.machine() == "aarch64" else CPUAdamBuilder().load()
+        cpu_adam = CPUAdamLoader().load()
         # if you find yourself stuck here, make sure that you install colossalai with CUDA_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
colossalai/nn/optimizer/fused_adam.py

@@ -70,9 +70,9 @@ class FusedAdam(torch.optim.Optimizer):
         self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
-            from colossalai.kernel.op_builder import FusedOptimBuilder
+            from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()

             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])