OpenDAS / ColossalAI / Commits

Commit 3acbf6d4 (unverified)
Authored Nov 22, 2023 by Xuanlei Zhao; committed by GitHub on Nov 22, 2023

[npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d
* update
* fix autocast

Parent: aae49663
Changes: 9 changed files with 61 additions and 40 deletions (+61 / -40)
colossalai/booster/mixed_precision/fp16_torch.py       +2  -1
colossalai/booster/plugin/hybrid_parallel_plugin.py    +9  -8
colossalai/device/device_mesh.py                       +1  -1
colossalai/legacy/amp/torch_amp/torch_amp.py           +4  -3
colossalai/legacy/utils/activation_checkpoint.py       +3  -3
colossalai/shardformer/layer/utils.py                  +16 -15
colossalai/testing/utils.py                            +7  -6
colossalai/utils/device.py                             +17 -1
examples/language/llama2/benchmark.py                  +2  -2
colossalai/booster/mixed_precision/fp16_torch.py

@@ -6,6 +6,7 @@ from torch import Tensor
 from torch.optim import Optimizer

 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils.device import autocast

 from .mixed_precision_base import MixedPrecision

@@ -88,7 +89,7 @@ class TorchAMPModule(ModelWrapper):
         super().__init__(module)

     def forward(self, *args, **kwargs):
-        with torch.cuda.amp.autocast():
+        with autocast():
             return self.module(*args, **kwargs)
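Note: the change above swaps the CUDA-specific torch.cuda.amp.autocast() for the device-agnostic autocast() helper introduced in colossalai/utils/device.py further down. A minimal usage sketch, assuming this commit's colossalai is installed and a CUDA device is present; the toy Linear model is illustrative:

    import torch
    from colossalai.utils.device import autocast

    model = torch.nn.Linear(8, 8).cuda()
    # autocast() resolves to torch.cuda.amp.autocast() on GPU,
    # torch.npu.amp.autocast() on an Ascend NPU (see device.py below).
    with autocast():
        out = model(torch.randn(4, 8, device="cuda"))
    print(out.dtype)  # typically torch.float16 under AMP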
colossalai/booster/plugin/hybrid_parallel_plugin.py

@@ -29,6 +29,7 @@ from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
 from colossalai.zero.low_level import LowLevelZeroOptimizer
+from colossalai.utils.device import get_current_device

 from .pp_plugin_base import PipelinePluginBase

@@ -81,7 +82,7 @@ class HybridParallelModule(ModelWrapper):
             self.mixed_precision = torch.bfloat16
         if self.mixed_precision is not None:
             module = module.to(self.mixed_precision)
-        module = module.cuda()
+        module = module.to(get_current_device())

         # setting input type cast when using mixed precision
         self.convert_fn = None

@@ -345,7 +346,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         if norm_type == inf:
             total_norm = max(grad.data.abs().max() for grad in gradients)
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)
             if self.pp_size > 1:

@@ -384,7 +385,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
                 total_norm_exponentiated += grad_norm_exponentiated

-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)

@@ -542,7 +543,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             # so we need to calculate the norm of 'tp' and 'pp' gradients.
             total_norm = super()._compute_grad_norm(param_gradient_pairs, norm_type)
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)

@@ -585,7 +586,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated

-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if self.tp_size > 1:
                 # compute norm in tp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.tp_pg)

@@ -797,7 +798,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             # so we only need to calculate the norm 'tp' of 'pp' gradients.
             total_norm = super()._compute_grad_norm(gradients, norm_type)
-            total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
+            total_norm_cuda = torch.tensor([float(total_norm)], device=get_current_device(), dtype=torch.float32)
             if tp_size > 1:
                 dist.all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX, group=self.tp_pg)

@@ -836,7 +837,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
                 total_norm_exponentiated += grad_norm_exponentiated

-            total_norm_exponentiated_cuda = torch.cuda.FloatTensor([float(total_norm_exponentiated)])
+            total_norm_exponentiated_cuda = torch.tensor([float(total_norm_exponentiated)], device=get_current_device(), dtype=torch.float32)
             if dp_size > 1:
                 # compute norm in dp process group
                 dist.all_reduce(tensor=total_norm_exponentiated_cuda, op=dist.ReduceOp.SUM, group=self.dp_pg)

@@ -1027,7 +1028,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         return self.pp_size > 1

     def supported_devices(self) -> List[str]:
-        return ["cuda"]
+        return ["cuda", "npu"]

     def supported_precisions(self) -> List[str]:
         return ["fp16", "bf16", "fp32"]
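Note: the recurring edit in this file replaces the legacy torch.cuda.FloatTensor constructor, which can only allocate on a CUDA GPU, with torch.tensor(..., device=get_current_device(), dtype=torch.float32), which targets whichever accelerator is active. A small before/after sketch (the norm value is a placeholder):

    import torch
    from colossalai.utils.device import get_current_device

    total_norm = 3.5
    # old, CUDA-only:
    #   total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
    # new, device-agnostic (same values, same dtype, portable device):
    total_norm_cuda = torch.tensor(
        [float(total_norm)], device=get_current_device(), dtype=torch.float32
    )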
colossalai/device/device_mesh.py

@@ -38,7 +38,7 @@ class DeviceMesh:
         device (str): the device for the process groups used by the DeviceMesh instance. (default: 'cuda')
     """

-    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo"}
+    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}

     def __init__(
         self,
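Note: HCCL is the collective-communication library for Ascend NPUs, playing the role NCCL plays on CUDA GPUs. A hedged sketch of how such a table maps a device string to a torch.distributed backend; the lookup helper is illustrative, not the actual DeviceMesh code path:

    _DIST_BACKEND = {"cuda": "nccl", "cpu": "gloo", "npu": "hccl"}

    def backend_for(device: str) -> str:
        # Pick the collective backend matching the device type; with the
        # new "npu" entry, process groups on Ascend hardware use HCCL.
        return _DIST_BACKEND[device]

    assert backend_for("npu") == "hccl"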
colossalai/legacy/amp/torch_amp/torch_amp.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-import torch.cuda.amp as torch_amp
+from colossalai.utils.device import autocast
 import torch.nn as nn
 from torch import Tensor
 from torch.nn.modules.loss import _Loss

@@ -70,7 +71,7 @@ class TorchAMPModel(nn.Module):
         super().__init__()
         self.model = model

-    @torch_amp.autocast()
+    @autocast()
     def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context

@@ -89,7 +90,7 @@ class TorchAMPLoss(nn.Module):
         super().__init__()
         self.loss = loss

-    @torch_amp.autocast()
+    @autocast()
     def forward(self, *args, **kwargs):
         """
         Execute forward under the torch amp context
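Note: torch autocast objects double as decorators, which is why @torch_amp.autocast() can be swapped for the device-agnostic @autocast() with no other changes. A minimal sketch, assuming a CUDA or NPU device is present when the class is defined (autocast() raises otherwise); the tiny model is illustrative:

    import torch
    from colossalai.utils.device import autocast

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 4)

        @autocast()  # forward always runs under the device's AMP context
        def forward(self, x):
            return self.fc(x)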
colossalai/legacy/utils/activation_checkpoint.py

@@ -7,7 +7,7 @@ import torch
 from torch.utils.checkpoint import check_backward_validity, detach_variable

 from colossalai.legacy.context.random import get_current_mode, get_states, set_mode, set_seed_states, sync_states
-from colossalai.utils import get_current_device
+from colossalai.utils.device import autocast, get_current_device


 def copy_to_device(obj, device):

@@ -110,7 +110,7 @@ class CheckpointFunction(torch.autograd.Function):
             inputs[idx] = tensors[i]
         detached_inputs = detach_variable(tuple(inputs))
         if ctx.had_autocast_in_fwd:
-            with torch.enable_grad(), torch.cuda.amp.autocast():
+            with torch.enable_grad(), autocast():
                 outputs = ctx.run_function(*detached_inputs)
         else:
             with torch.enable_grad():

@@ -226,7 +226,7 @@ def _checkpoint_without_reentrant(function, activation_offload=False, *args):
         # rerun forward, the inner_pack will store all the activations in storage
         if has_autocast_in_fwd:
-            with torch.enable_grad(), torch.cuda.amp.autocast(), torch.autograd.graph.saved_tensors_hooks(
+            with torch.enable_grad(), autocast(), torch.autograd.graph.saved_tensors_hooks(
                 inner_pack, inner_unpack
             ):
                 _unused = function(*args)
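Note: a checkpointed forward that originally ran under AMP must be replayed under the same autocast context at backward time, or the recomputed activations would come back in the wrong dtype. A condensed sketch of the pattern; run_function and detached_inputs stand in for the checkpoint internals:

    import torch
    from colossalai.utils.device import autocast

    def replay(run_function, detached_inputs, had_autocast_in_fwd):
        # Recompute the forward with gradients enabled, re-entering the
        # device-agnostic autocast only if the original pass used AMP.
        if had_autocast_in_fwd:
            with torch.enable_grad(), autocast():
                return run_function(*detached_inputs)
        with torch.enable_grad():
            return run_function(*detached_inputs)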
colossalai/shardformer/layer/utils.py

@@ -6,6 +6,7 @@ import torch.distributed as dist
 from torch import nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup, get_world_size
+from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed


 class SeqParallelUtils:

@@ -104,14 +105,14 @@ class Randomizer:
     def __init__(self, seed: int):
         self.seed = seed

-        # Handle CUDA rng state
+        # Handle device rng state
         # 1. get the current rng state
         # 2. set the seed and store the rng state
         # 3. recover the original rng state
-        cuda_original_rng_state = torch.cuda.get_rng_state()
-        torch.cuda.manual_seed(seed)
-        self.cuda_rng_state = torch.cuda.get_rng_state()
-        torch.cuda.set_rng_state(cuda_original_rng_state)
+        device_original_rng_state = get_rng_state()
+        manual_seed(seed)
+        self.device_rng_state = get_rng_state()
+        set_rng_state(device_original_rng_state)

         # to the same for cpu rng state
         cpu_original_rng_state = torch.get_rng_state()

@@ -119,11 +120,11 @@ class Randomizer:
         self.cpu_rng_state = torch.get_rng_state()
         torch.set_rng_state(cpu_original_rng_state)

-    def _set_cuda_rng_state(self, rng_state):
-        torch.cuda.set_rng_state(rng_state)
+    def _set_device_rng_state(self, rng_state):
+        set_rng_state(rng_state)

-    def _get_cuda_rng_state(self):
-        current_state = torch.cuda.get_rng_state()
+    def _get_device_rng_state(self):
+        current_state = get_rng_state()
         return current_state

     def _set_cpu_rng_state(self, rng_state):

@@ -144,16 +145,16 @@ class Randomizer:
         >>> input = super().forward(input)
         """
         try:
-            current_cuda_rng_state = self._get_cuda_rng_state()
-            self._set_cuda_rng_state(self.cuda_rng_state)
+            current_device_rng_state = self._get_device_rng_state()
+            self._set_device_rng_state(self.device_rng_state)
             if enable_cpu:
                 current_cpu_rng_state = self._get_cpu_rng_state()
                 self._set_cpu_rng_state(self.cpu_rng_state)
             yield
         finally:
-            self.cuda_rng_state = self._get_cuda_rng_state()
-            self._set_cuda_rng_state(current_cuda_rng_state)
+            self.device_rng_state = self._get_device_rng_state()
+            self._set_device_rng_state(current_device_rng_state)
             if enable_cpu:
                 self.cpu_rng_state = self._get_cpu_rng_state()

@@ -208,7 +209,7 @@ class Randomizer:
         index = Randomizer.index()
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]

@@ -230,7 +231,7 @@ class Randomizer:
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32).cuda()
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
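Note: the Randomizer edits route all RNG handling through the device-agnostic helpers so the save/seed/restore dance also works on NPU. A condensed sketch of the same pattern, assuming a CUDA or NPU device; the real Randomizer additionally tracks the CPU RNG state:

    from contextlib import contextmanager
    from colossalai.utils.device import get_rng_state, set_rng_state, manual_seed

    class MiniRandomizer:
        def __init__(self, seed: int):
            original = get_rng_state()               # 1. stash the live device RNG state
            manual_seed(seed)                        # 2. derive a state from the seed...
            self.device_rng_state = get_rng_state()  #    ...and store it
            set_rng_state(original)                  # 3. put the live state back

        @contextmanager
        def fork(self):
            # Temporarily run under the stored RNG state, then restore.
            outer = get_rng_state()
            set_rng_state(self.device_rng_state)
            try:
                yield
            finally:
                self.device_rng_state = get_rng_state()
                set_rng_state(outer)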
colossalai/testing/utils.py

@@ -9,6 +9,7 @@ from typing import Any, Callable, List
 import torch
 import torch.multiprocessing as mp
 from packaging import version
+from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count


 def parameterize(argument: str, values: List[Any]) -> Callable:

@@ -198,7 +199,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
     def _wrap_func(f):
         def _execute_by_gpu_num(*args, **kwargs):
-            num_avail_gpu = torch.cuda.device_count()
+            num_avail_gpu = device_count()
             if num_avail_gpu >= min_gpus:
                 f(*args, **kwargs)

@@ -262,11 +263,11 @@ def clear_cache_before_run():
     def _wrap_func(f):
         def _clear_cache(*args, **kwargs):
-            torch.cuda.empty_cache()
-            torch.cuda.reset_peak_memory_stats()
-            torch.cuda.reset_max_memory_allocated()
-            torch.cuda.reset_max_memory_cached()
-            torch.cuda.synchronize()
+            empty_cache()
+            reset_peak_memory_stats()
+            reset_max_memory_allocated()
+            reset_max_memory_cached()
+            synchronize()
             gc.collect()
             f(*args, **kwargs)
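Note: with device_count, empty_cache, and the memory-stat resets routed through colossalai.utils.device, the test decorators in this file now work on NPU as well as CUDA. A hedged usage sketch; the test body is illustrative:

    from colossalai.testing.utils import clear_cache_before_run, skip_if_not_enough_gpus

    @skip_if_not_enough_gpus(min_gpus=2)   # counts devices via device_count()
    @clear_cache_before_run()              # clears device caches and stats first
    def test_allreduce_sketch():
        ...  # would set up a process group and exercise a collective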
colossalai/utils/device.py

 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Callable

 import torch
 import torch.distributed as dist

@@ -191,6 +191,10 @@ def reset_max_memory_allocated(device=None) -> None:
     return _dispatch_device_func("reset_max_memory_allocated", device)


+def reset_max_memory_cached(device=None) -> None:
+    return _dispatch_device_func("reset_max_memory_cached", device)
+
+
 def memory_reserved(device=None) -> int:
     return _dispatch_device_func("memory_reserved", device)

@@ -205,3 +209,15 @@ def set_per_process_memory_fraction(fraction: float, device=None) -> None:
 def reset_peak_memory_stats(device=None) -> None:
     return _dispatch_device_func("reset_peak_memory_stats", device)
+
+
+# amp
+def autocast() -> Callable:
+    if torch.cuda.is_available():
+        return torch.cuda.amp.autocast()
+    elif IS_NPU_AVAILABLE:
+        return torch.npu.amp.autocast()
+    else:
+        raise RuntimeError("No device available")
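Note: _dispatch_device_func and IS_NPU_AVAILABLE are defined elsewhere in this module and are not shown in this diff. A purely illustrative sketch of the dispatch idea behind _dispatch_device_func; the real implementation may differ:

    import torch

    def _dispatch_device_func_sketch(fn_name: str, *args):
        # Look up the same-named function on torch.cuda or torch.npu,
        # depending on which backend is live (hypothetical helper).
        if torch.cuda.is_available():
            return getattr(torch.cuda, fn_name)(*args)
        if getattr(torch, "npu", None) is not None and torch.npu.is_available():
            return getattr(torch.npu, fn_name)(*args)
        raise RuntimeError("No device available")

    # e.g. _dispatch_device_func_sketch("reset_peak_memory_stats", None)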
examples/language/llama2/benchmark.py

@@ -131,7 +131,7 @@ def main():
             tp_size=args.tp,
             pp_size=args.pp,
             zero_stage=args.zero,
-            enable_fused_normalization=True,
+            enable_fused_normalization=torch.cuda.is_available(),
             num_microbatches=args.mbs,
             precision="bf16",
         )

@@ -141,7 +141,7 @@ def main():
             pp_size=args.pp,
             zero_stage=args.zero,
             cpu_offload=True,
-            enable_fused_normalization=True,
+            enable_fused_normalization=torch.cuda.is_available(),
             num_microbatches=args.mbs,
             initial_scale=2**8,
             precision="bf16",
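Note: the benchmark now gates fused normalization on CUDA availability, since the fused kernel is CUDA-specific; without it the plugin presumably falls back to the unfused implementation on NPU. A hedged fragment showing the construction (HybridParallelPlugin needs a colossalai-launched distributed context to actually instantiate):

    import torch
    from colossalai.booster.plugin import HybridParallelPlugin

    def build_plugin(args):
        return HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero,
            # True on CUDA GPUs, False on NPU, so one script serves both.
            enable_fused_normalization=torch.cuda.is_available(),
            num_microbatches=args.mbs,
            precision="bf16",
        )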