ColossalAI · Commit 8823cc48 (Unverified)
Authored Jan 29, 2024 by Frank Lee; committed via GitHub on Jan 29, 2024.

    Merge pull request #5310 from hpcaitech/feature/npu

    Feature/npu

Parents: bce9499e, 73f4dc57

Changes: 266 files in total. This page shows 20 changed files with 71 additions and 276 deletions (+71 / -276).
Files on this page:

colossalai/nn/optimizer/fused_lamb.py              +2   -2
colossalai/nn/optimizer/fused_sgd.py               +2   -2
colossalai/nn/optimizer/hybrid_adam.py             +2   -2
colossalai/pipeline/schedule/generate.py           +2   -2
colossalai/pipeline/schedule/interleaved_pp.py     +3   -2
colossalai/pipeline/schedule/one_f_one_b.py        +4   -3
colossalai/shardformer/layer/utils.py              +10  -9
colossalai/shardformer/modeling/blip2.py           +1   -1
colossalai/shardformer/modeling/chatglm2.py        +1   -1
colossalai/shardformer/modeling/gpt2.py            +1   -1
colossalai/shardformer/modeling/gptj.py            +1   -1
colossalai/shardformer/modeling/llama.py           +19  -9
colossalai/shardformer/modeling/mistral.py         +1   -1
colossalai/shardformer/modeling/opt.py             +1   -1
colossalai/shardformer/modeling/vit.py             +1   -1
colossalai/shardformer/modeling/whisper.py         +1   -1
colossalai/testing/utils.py                        +8   -7
colossalai/utils/__init__.py                       +2   -7
colossalai/utils/common.py                         +9   -0
colossalai/utils/device.py                         +0   -223
colossalai/nn/optimizer/fused_lamb.py

@@ -77,9 +77,9 @@ class FusedLAMB(torch.optim.Optimizer):
         )
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
-            from colossalai.kernel.op_builder import FusedOptimBuilder
+            from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()

             self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
             # Skip buffer
colossalai/nn/optimizer/fused_sgd.py

@@ -72,9 +72,9 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum

         if multi_tensor_applier.available:
-            from colossalai.kernel.op_builder import FusedOptimBuilder
+            from colossalai.kernel.kernel_loader import FusedOptimizerLoader

-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()

             # Skip buffer
             self._dummy_overflow_buf = torch.tensor(
colossalai/nn/optimizer/hybrid_adam.py

@@ -2,7 +2,7 @@ from typing import Any, Optional

 import torch

-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
 from colossalai.utils import multi_tensor_applier

 from .cpu_adam import CPUAdam

@@ -85,7 +85,7 @@ class HybridAdam(CPUAdam):
             nvme_offload_dir,
         )
         if torch.cuda.is_available():
-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()
             self.gpu_adam_op = fused_optim.multi_tensor_adam
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
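The same two-line substitution appears in all three fused optimizers above: the compiled extension is now resolved through the kernel loader instead of the op builder. A minimal sketch of the new loading path, assuming a ColossalAI build that ships colossalai.kernel.kernel_loader and a device toolchain that can build (or has prebuilt) the fused optimizer extension:

# Sketch only: mirrors the import/load pattern introduced in this commit.
import torch

from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import multi_tensor_applier

if multi_tensor_applier.available:
    fused_optim = FusedOptimizerLoader().load()      # builds or loads the fused-optimizer extension
    dummy_overflow_buf = torch.cuda.IntTensor([0])   # overflow flag consumed by the multi-tensor kernels
    # fused_optim.multi_tensor_adam / fused_optim.multi_tensor_l2norm are then handed to
    # multi_tensor_applier together with the parameter and gradient tensor lists.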
colossalai/pipeline/schedule/generate.py

@@ -7,10 +7,10 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map

+from colossalai.accelerator import get_accelerator
 from colossalai.inference.engine.microbatch_manager import MicroBatchManager, Status
 from colossalai.pipeline.p2p import PipelineP2PCommunication
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device

 from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule

@@ -86,7 +86,7 @@ class GenerateSchedule(PipelineSchedule):
         """
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)

     def _prepare_inputs_for_interval_stage(self):
         """
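Throughout the pipeline schedules, the free function get_current_device() from the deleted colossalai.utils.device module is replaced by the accelerator API. A minimal sketch of the new call, assuming a ColossalAI install that provides colossalai.accelerator; the tensor below is purely illustrative:

# Sketch: resolve the current compute device through the accelerator abstraction
# (CUDA, or Ascend NPU when torch_npu is installed).
import torch

from colossalai.accelerator import get_accelerator

device = get_accelerator().get_current_device()  # e.g. cuda:0 or npu:0
x = torch.zeros(4, device=device)                # illustrative tensor placed on that device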
colossalai/pipeline/schedule/interleaved_pp.py

@@ -6,10 +6,11 @@ import torch.cuda
 from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map

+from colossalai.accelerator import get_accelerator
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
+from colossalai.utils import get_current_device

 from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
 from .base import PipelineSchedule

@@ -100,7 +101,7 @@ class InterleavedSchedule(PipelineSchedule):
         assert self.microbatch_offset[model_chunk_id] <= self.batch_size, "Microbatches exhausted"
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)

     def get_model_chunk_id(self, microbatch_id: int, is_forward: bool) -> int:
         """Helper method to get the model chunk ID given the iteration number.
colossalai/pipeline/schedule/one_f_one_b.py

@@ -6,10 +6,11 @@ import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map

+from colossalai.accelerator import get_accelerator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.utils.device import get_current_device
+from colossalai.utils import get_current_device

 from ._utils import (
     detach,

@@ -110,7 +111,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         assert self.microbatch_offset <= self.batch_size, "Microbatches exhausted"
         micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
         self.microbatch_offset += self.microbatch_size
-        return tree_map(partial(to_device, device=get_current_device()), micro_batch)
+        return tree_map(partial(to_device, device=get_accelerator().get_current_device()), micro_batch)

     def recv_forward(self, prev_rank: int = None) -> Any:
         """Copy the forward output from the previous stage in pipeline as the input tensor of this stage.

@@ -317,7 +318,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         accum_loss = None
         if return_loss and self.stage_manager.is_last_stage():
-            accum_loss = torch.scalar_tensor(0, device=get_current_device())
+            accum_loss = torch.scalar_tensor(0, device=get_accelerator().get_current_device())
         outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None

         for _ in range(self.num_microbatches):
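Each schedule loads a micro-batch and then maps every tensor inside it onto the compute device with tree_map(partial(to_device, device=...), micro_batch). A self-contained sketch of that pattern; the to_device helper here is a stand-in for the one in colossalai.pipeline.schedule._utils:

from functools import partial

import torch
from torch.utils._pytree import tree_map


def to_device(x, device):
    # Stand-in for colossalai.pipeline.schedule._utils.to_device:
    # move tensors, leave every other leaf (ints, strings, ...) untouched.
    return x.to(device) if isinstance(x, torch.Tensor) else x


micro_batch = {
    "input_ids": torch.ones(2, 8, dtype=torch.long),
    "labels": torch.zeros(2, 8, dtype=torch.long),
}
device = "cuda" if torch.cuda.is_available() else "cpu"
micro_batch = tree_map(partial(to_device, device=device), micro_batch)  # every tensor now lives on `device`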
colossalai/shardformer/layer/utils.py

@@ -6,7 +6,8 @@ import torch.distributed as dist
 from torch import nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
 from torch.distributed import ProcessGroup, get_world_size

-from colossalai.utils.device import get_current_device, get_rng_state, set_rng_state, manual_seed
+from colossalai.accelerator import get_accelerator


 class SeqParallelUtils:

@@ -109,10 +110,10 @@ class Randomizer:
         # 1. get the current rng state
         # 2. set the seed and store the rng state
         # 3. recover the original rng state
-        device_original_rng_state = get_rng_state()
-        manual_seed(seed)
-        self.device_rng_state = get_rng_state()
-        set_rng_state(device_original_rng_state)
+        device_original_rng_state = get_accelerator().get_rng_state()
+        get_accelerator().manual_seed(seed)
+        self.device_rng_state = get_accelerator().get_rng_state()
+        get_accelerator().set_rng_state(device_original_rng_state)

         # to the same for cpu rng state
         cpu_original_rng_state = torch.get_rng_state()

@@ -121,10 +122,10 @@ class Randomizer:
         torch.set_rng_state(cpu_original_rng_state)

     def _set_device_rng_state(self, rng_state):
-        set_rng_state(rng_state)
+        get_accelerator().set_rng_state(rng_state)

     def _get_device_rng_state(self):
-        current_state = get_rng_state()
+        current_state = get_accelerator().get_rng_state()
         return current_state

     def _set_cpu_rng_state(self, rng_state):

@@ -209,7 +210,7 @@ class Randomizer:
         index = Randomizer.index()
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())

             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]

@@ -231,7 +232,7 @@ class Randomizer:
         if dist.is_initialized():
             # convert the index to tensor
-            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_current_device())
+            index_tensor = torch.tensor(index, dtype=torch.int32, device=get_accelerator().get_current_device())

             # all gather the index
             gathered_index = [torch.zeros_like(index_tensor) for _ in range(dist.get_world_size(process_group))]
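The Randomizer change swaps the free functions from the deleted device module for accelerator methods while keeping the fork-and-restore logic: read the current device RNG state, seed, capture the seeded state, then restore the original. A hedged sketch of that sequence, assuming a ColossalAI install with the accelerator API:

# Sketch of the seed-capture pattern used by Randomizer after this commit.
from colossalai.accelerator import get_accelerator

seed = 1024
accelerator = get_accelerator()

original_state = accelerator.get_rng_state()  # 1. remember the current device RNG state
accelerator.manual_seed(seed)                 # 2. seed the device RNG
seeded_state = accelerator.get_rng_state()    #    ...and capture the seeded state for later reuse
accelerator.set_rng_state(original_state)     # 3. put the original state back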
colossalai/shardformer/modeling/blip2.py

@@ -62,7 +62,7 @@ def forward_fn():
 def get_blip2_flash_attention_forward():
     from transformers.models.blip_2.modeling_blip_2 import Blip2Attention

-    from colossalai.kernel.cuda_native import ColoAttention
+    from colossalai.nn.layer.colo_attention import ColoAttention

     def forward(
         self: Blip2Attention,
colossalai/shardformer/modeling/chatglm2.py

@@ -14,7 +14,7 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLM
 def get_flash_core_attention_forward():
-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

     from .chatglm2_6b.modeling_chatglm import CoreAttention
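Across the shardformer model files, the flash-attention helpers now import ColoAttention and AttnMaskType from colossalai.nn.layer.colo_attention instead of colossalai.kernel.cuda_native. Code that has to run against both older and newer ColossalAI releases could guard the import; a hypothetical compatibility shim, not part of this commit:

# Hypothetical shim: prefer the new location, fall back to the old one.
try:
    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
except ImportError:
    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention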
colossalai/shardformer/modeling/gpt2.py

@@ -719,7 +719,7 @@ class GPT2PipelineForwards:
 def get_gpt2_flash_attention_forward():
     from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

     def split_heads(tensor, num_heads, attn_head_size):
         """
colossalai/shardformer/modeling/gptj.py

@@ -530,7 +530,7 @@ class GPTJPipelineForwards:
 def get_gptj_flash_attention_forward():
     from transformers.models.gptj.modeling_gptj import GPTJAttention

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

     def split_heads(tensor, num_attention_heads, attn_head_size, rotary):
         """
colossalai/shardformer/modeling/llama.py

@@ -3,7 +3,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 import torch.distributed as dist
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,

@@ -15,14 +14,17 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d

+try:
+    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
+
+    LATEST_VERSION = True
+except ImportError:
+    LATEST_VERSION = False


 class LlamaPipelineForwards:
     """
     This class serves as a micro library for forward function substitution of Llama models

@@ -203,7 +205,7 @@ class LlamaPipelineForwards:
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
         stage_index: Optional[List[int]] = None,
-        shard_config: ShardConfig = None
+        shard_config: ShardConfig = None,
     ):
         r"""
         Args:

@@ -279,12 +281,13 @@ class LlamaPipelineForwards:
             if shard_config.enable_tensor_parallelism:
                 new_vocab_size = logits.shape[-1]
                 shift_logits = shift_logits.view(-1, new_vocab_size)
-                loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group)
+                loss = cross_entropy_1d(
+                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                )
             else:
                 shift_logits = shift_logits.view(-1, self.config.vocab_size)
                 loss = loss_fct(shift_logits, shift_labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output

@@ -417,7 +420,7 @@ class LlamaPipelineForwards:
 def get_llama_flash_attention_forward(shard_config: ShardConfig):
     from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

     llama_version = 2
     try:

@@ -480,7 +483,12 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
         attention = ColoAttention(embed_dim=self.hidden_size, num_heads=self.num_heads)
         attn_output = attention(
-            query_states, key_states, value_states, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=flash_attention_mask,
+            attn_mask_type=attn_mask_type,
+            origin_attn_mask=attention_mask,
         )

         attn_output = self.o_proj(attn_output)

@@ -492,7 +500,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig):
 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import LlamaForCausalLM

     def forward(
         self: LlamaForCausalLM,
         input_ids: torch.LongTensor = None,

@@ -573,12 +581,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
         if shard_config.enable_tensor_parallelism:
             new_vocab_size = logits.shape[-1]
             shift_logits = shift_logits.view(-1, new_vocab_size)
-            loss = cross_entropy_1d(shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group)
+            loss = cross_entropy_1d(
+                shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+            )
         else:
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             loss = loss_fct(shift_logits, shift_labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output

@@ -590,4 +599,5 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )

     return forward
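The LATEST_VERSION flag added at the top of llama.py is a version guard: _prepare_4d_causal_attention_mask only exists in newer transformers releases, so the module probes for it once at import time and branches on the flag where the mask is built. A short, self-contained sketch of that probe-once pattern, assuming transformers is installed:

# Probe once for an API that only newer transformers releases provide,
# then branch on the flag wherever the causal mask is constructed.
try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask  # noqa: F401

    LATEST_VERSION = True
except ImportError:
    LATEST_VERSION = False

print("new-style 4D causal mask helper available:", LATEST_VERSION)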
colossalai/shardformer/modeling/mistral.py

@@ -6,7 +6,7 @@ import torch
 def get_mistral_flash_attention_forward():
     from transformers.models.mistral.modeling_mistral import MistralAttention, apply_rotary_pos_emb, repeat_kv

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

     def forward(
         self: MistralAttention,
colossalai/shardformer/modeling/opt.py

@@ -514,7 +514,7 @@ class OPTPipelineForwards:
 def get_opt_flash_attention_forward():
     from transformers.models.opt.modeling_opt import OPTAttention

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

     def forward(
         self: OPTAttention,
colossalai/shardformer/modeling/vit.py

@@ -336,7 +336,7 @@ def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManag
 def get_vit_flash_self_attention_forward():
     from transformers.models.vit.modeling_vit import ViTSelfAttention

-    from colossalai.kernel.cuda_native import ColoAttention
+    from colossalai.nn.layer.colo_attention import ColoAttention

     def transpose_for_scores(x: torch.Tensor, num_attention_heads, attention_head_size) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (num_attention_heads, attention_head_size)
colossalai/shardformer/modeling/whisper.py

@@ -26,7 +26,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 def get_whisper_flash_attention_forward():
     from transformers.models.whisper.modeling_whisper import WhisperAttention

-    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
+    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention

     def shape(tensor: torch.Tensor, seq_len: int, bsz: int, num_heads: int, head_dim: int):
         return tensor.view(bsz, seq_len, num_heads, head_dim).contiguous()
colossalai/testing/utils.py

@@ -9,7 +9,8 @@ from typing import Any, Callable, List
 import torch
 import torch.multiprocessing as mp
 from packaging import version

-from colossalai.utils.device import empty_cache, reset_max_memory_allocated, reset_peak_memory_stats, synchronize, reset_max_memory_cached, device_count
+from colossalai.accelerator import get_accelerator


 def parameterize(argument: str, values: List[Any]) -> Callable:

@@ -199,7 +200,7 @@ def skip_if_not_enough_gpus(min_gpus: int):
     def _wrap_func(f):
         def _execute_by_gpu_num(*args, **kwargs):
-            num_avail_gpu = device_count()
+            num_avail_gpu = get_accelerator().device_count()
             if num_avail_gpu >= min_gpus:
                 f(*args, **kwargs)

@@ -263,11 +264,11 @@ def clear_cache_before_run():
     def _wrap_func(f):
         def _clear_cache(*args, **kwargs):
-            empty_cache()
-            reset_peak_memory_stats()
-            reset_max_memory_allocated()
-            reset_max_memory_cached()
-            synchronize()
+            get_accelerator().empty_cache()
+            get_accelerator().reset_peak_memory_stats()
+            get_accelerator().reset_max_memory_allocated()
+            get_accelerator().reset_max_memory_cached()
+            get_accelerator().synchronize()
             gc.collect()
             f(*args, **kwargs)
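clear_cache_before_run keeps its decorator structure; only the body now goes through the accelerator singleton. A condensed sketch of the wrapper, assuming a ColossalAI install with the accelerator API; the decorated function is illustrative:

import gc
from functools import wraps

from colossalai.accelerator import get_accelerator


def clear_cache_before_run():
    """Decorator sketch: release cached device memory and reset stats before the wrapped call."""

    def _wrap_func(f):
        @wraps(f)
        def _clear_cache(*args, **kwargs):
            acc = get_accelerator()
            acc.empty_cache()
            acc.reset_peak_memory_stats()
            acc.reset_max_memory_allocated()
            acc.reset_max_memory_cached()
            acc.synchronize()
            gc.collect()
            return f(*args, **kwargs)

        return _clear_cache

    return _wrap_func


@clear_cache_before_run()
def run_step():
    ...  # illustrative workload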
colossalai/utils/__init__.py

@@ -4,20 +4,16 @@ from .common import (
     disposable,
     ensure_path_exists,
     free_storage,
+    get_current_device,
     is_ddp_ignored,
     set_seed,
 )
-from .device import IS_NPU_AVAILABLE, empty_cache, get_current_device, set_device, set_to_cuda, synchronize
 from .multi_tensor_apply import multi_tensor_applier
 from .tensor_detector import TensorDetector
 from .timer import MultiTimer, Timer

 __all__ = [
     "conditional_context",
     "get_current_device",
-    "synchronize",
-    "empty_cache",
-    "set_to_cuda",
     "Timer",
     "MultiTimer",
     "multi_tensor_applier",

@@ -27,7 +23,6 @@ __all__ = [
     "_cast_float",
     "free_storage",
     "set_seed",
     "get_current_device",
     "is_ddp_ignored",
-    "set_device",
-    "IS_NPU_AVAILABLE",
 ]
colossalai/utils/common.py

@@ -10,6 +10,15 @@ from typing import Callable
 import numpy as np
 import torch

+from colossalai.accelerator import get_accelerator
+
+
+def get_current_device():
+    """
+    A wrapper function for accelerator's API for backward compatibility.
+    """
+    return get_accelerator().get_current_device()
+

 def ensure_path_exists(filename: str):
     # ensure the path exists
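With colossalai/utils/device.py gone, get_current_device survives as a thin wrapper in colossalai/utils/common.py, so call sites that import it from colossalai.utils keep working. A short usage sketch, assuming ColossalAI is installed; the allocation is illustrative:

import torch

# The old colossalai.utils.device import path is removed; both of these now
# resolve to the accelerator-backed implementation.
from colossalai.accelerator import get_accelerator
from colossalai.utils import get_current_device

assert get_current_device() == get_accelerator().get_current_device()
buf = torch.empty(16, device=get_current_device())  # illustrative allocation on the current device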
colossalai/utils/device.py
deleted, 100644 → 0 (the entire file below is removed by this commit)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from typing import Any, Dict, List, Optional, Tuple, Callable

import torch
import torch.distributed as dist

IS_NPU_AVAILABLE: bool = False
try:
    import torch_npu  # noqa

    IS_NPU_AVAILABLE = torch.npu.is_available()
except ImportError:
    pass


def set_to_cuda(models):
    """Send model to gpu.

    :param models: nn.module or a list of module
    """
    if isinstance(models, list) and len(models) > 1:
        ret = []
        for model in models:
            ret.append(model.to(get_current_device()))
        return ret
    elif isinstance(models, list):
        return models[0].to(get_current_device())
    else:
        return models.to(get_current_device())


def get_current_device() -> torch.device:
    """
    Returns currently selected device (gpu/cpu).
    If cuda available, return gpu, otherwise return cpu.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    elif IS_NPU_AVAILABLE:
        return torch.device(f"npu:{torch.npu.current_device()}")
    else:
        return torch.device("cpu")


def _dispatch_device_func(fn_name: str, *args, **kwargs):
    if torch.cuda.is_available():
        return getattr(torch.cuda, fn_name)(*args, **kwargs)
    elif IS_NPU_AVAILABLE:
        return getattr(torch.npu, fn_name)(*args, **kwargs)
    else:
        raise RuntimeError("No device available")


# device semantics


def can_device_access_peer(device, peer_device) -> bool:
    return _dispatch_device_func("can_device_access_peer", device, peer_device)


def current_device() -> int:
    return _dispatch_device_func("current_device")


def current_stream(device=None):
    return _dispatch_device_func("current_stream", device)


def default_stream(device=None):
    return _dispatch_device_func("default_stream", device)


def device_count() -> int:
    return _dispatch_device_func("device_count")


def get_device_capability(device=None) -> Tuple[int, int]:
    return _dispatch_device_func("get_device_capability", device)


def get_device_name(device=None) -> str:
    return _dispatch_device_func("get_device_name", device)


def get_device_properties(device):
    return _dispatch_device_func("get_device_properties", device)


def set_device(index: Optional[int] = None) -> None:
    if index is None:
        index = dist.get_rank() % device_count()
    _dispatch_device_func("set_device", index)


def set_stream(stream_):
    return _dispatch_device_func("set_stream", stream_)


def stream(stream_):
    return _dispatch_device_func("stream", stream_)


def synchronize():
    return _dispatch_device_func("synchronize")


def utilization(device=None) -> int:
    return _dispatch_device_func("utilization", device)


# random number generator


def get_rng_state(device="cuda") -> torch.Tensor:
    return _dispatch_device_func("get_rng_state", device)


def get_rng_state_all() -> List[torch.Tensor]:
    return _dispatch_device_func("get_rng_state_all")


def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None:
    return _dispatch_device_func("set_rng_state", new_state, device)


def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None:
    return _dispatch_device_func("set_rng_state_all", new_states)


def manual_seed(seed: int) -> None:
    return _dispatch_device_func("manual_seed", seed)


def manual_seed_all(seed: int) -> None:
    return _dispatch_device_func("manual_seed_all", seed)


def seed() -> None:
    return _dispatch_device_func("seed")


def seed_all() -> None:
    return _dispatch_device_func("seed_all")


def initial_seed() -> int:
    return _dispatch_device_func("initial_seed")


# streams and events


def Stream(device=None, priority=0, **kwargs):
    return _dispatch_device_func("Stream", device, priority, **kwargs)


def Event(enable_timing: bool = False, blocking: bool = False, interprocess: bool = False):
    return _dispatch_device_func("Event", enable_timing, blocking, interprocess)


# memory management


def empty_cache() -> None:
    return _dispatch_device_func("empty_cache")


def memory_stats(device=None) -> Dict[str, Any]:
    return _dispatch_device_func("memory_stats", device)


def memory_summary(device=None, abbreviated=False) -> str:
    return _dispatch_device_func("memory_summary", device, abbreviated)


def memory_snapshot():
    return _dispatch_device_func("memory_snapshot")


def memory_allocated(device=None) -> int:
    return _dispatch_device_func("memory_allocated", device)


def max_memory_allocated(device=None) -> int:
    return _dispatch_device_func("max_memory_allocated", device)


def reset_max_memory_allocated(device=None) -> None:
    return _dispatch_device_func("reset_max_memory_allocated", device)


def reset_max_memory_cached(device=None) -> None:
    return _dispatch_device_func("reset_max_memory_cached", device)


def memory_reserved(device=None) -> int:
    return _dispatch_device_func("memory_reserved", device)


def max_memory_reserved(device=None) -> int:
    return _dispatch_device_func("max_memory_reserved", device)


def set_per_process_memory_fraction(fraction: float, device=None) -> None:
    return _dispatch_device_func("set_per_process_memory_fraction", fraction, device)


def reset_peak_memory_stats(device=None) -> None:
    return _dispatch_device_func("reset_peak_memory_stats", device)


# amp


def autocast() -> Callable:
    if torch.cuda.is_available():
        return torch.cuda.amp.autocast()
    elif IS_NPU_AVAILABLE:
        return torch.npu.amp.autocast()
    else:
        raise RuntimeError("No device available")
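The deleted module above was a hand-rolled dispatcher that branched on torch.cuda / torch.npu for every call; the accelerator API now absorbs that responsibility. A rough before/after mapping for the call sites touched in this commit, illustrative and not exhaustive:

# Before (colossalai.utils.device, removed here)   ->  After (this commit)
#   get_current_device()                            ->  get_accelerator().get_current_device()
#   synchronize()                                   ->  get_accelerator().synchronize()
#   empty_cache()                                   ->  get_accelerator().empty_cache()
#   device_count()                                  ->  get_accelerator().device_count()
#   get_rng_state() / set_rng_state() / manual_seed()
#                                                   ->  get_accelerator().get_rng_state() / set_rng_state() / manual_seed()
from colossalai.accelerator import get_accelerator

acc = get_accelerator()
acc.synchronize()   # was colossalai.utils.device.synchronize()
acc.empty_cache()   # was colossalai.utils.device.empty_cache()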