Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
5bcc463d
Commit
5bcc463d
authored
May 29, 2023
by
aiss
Browse files
update v0.9.2
parent
ac5fbab4
Changes
554
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
494 additions
and
744 deletions
+494
-744
deepspeed/autotuning/utils.py
deepspeed/autotuning/utils.py
+8
-7
deepspeed/checkpoint/__init__.py
deepspeed/checkpoint/__init__.py
+5
-4
deepspeed/checkpoint/constants.py
deepspeed/checkpoint/constants.py
+9
-4
deepspeed/checkpoint/deepspeed_checkpoint.py
deepspeed/checkpoint/deepspeed_checkpoint.py
+31
-66
deepspeed/checkpoint/reshape_3d_utils.py
deepspeed/checkpoint/reshape_3d_utils.py
+14
-23
deepspeed/checkpoint/reshape_meg_2d.py
deepspeed/checkpoint/reshape_meg_2d.py
+10
-16
deepspeed/checkpoint/reshape_utils.py
deepspeed/checkpoint/reshape_utils.py
+7
-11
deepspeed/checkpoint/universal_checkpoint.py
deepspeed/checkpoint/universal_checkpoint.py
+11
-24
deepspeed/checkpoint/utils.py
deepspeed/checkpoint/utils.py
+5
-5
deepspeed/checkpoint/zero_checkpoint.py
deepspeed/checkpoint/zero_checkpoint.py
+9
-17
deepspeed/comm/__init__.py
deepspeed/comm/__init__.py
+5
-50
deepspeed/comm/backend.py
deepspeed/comm/backend.py
+10
-6
deepspeed/comm/comm.py
deepspeed/comm/comm.py
+140
-193
deepspeed/comm/config.py
deepspeed/comm/config.py
+6
-5
deepspeed/comm/constants.py
deepspeed/comm/constants.py
+4
-1
deepspeed/comm/torch.py
deepspeed/comm/torch.py
+121
-82
deepspeed/comm/utils.py
deepspeed/comm/utils.py
+5
-35
deepspeed/compression/__init__.py
deepspeed/compression/__init__.py
+4
-1
deepspeed/compression/basic_layer.py
deepspeed/compression/basic_layer.py
+60
-145
deepspeed/compression/compress.py
deepspeed/compression/compress.py
+30
-49
No files found.
Too many changes to show.
To preserve performance only
554 of 554+
files are displayed.
Plain diff
Email patch
deepspeed/autotuning/utils.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
re
import
collections.abc
...
...
@@ -176,6 +179,7 @@ def fetch_hostfile(hostfile_path):
def
validate_ds_config
(
config
:
dict
):
def
is_False
(
config
:
dict
,
key
):
if
config
is
None
:
return
False
...
...
@@ -189,9 +193,7 @@ def validate_ds_config(config: dict):
if
stage
==
1
:
return
True
elif
stage
==
2
:
if
is_False
(
config_zero
,
"cpu_offload"
)
and
is_False
(
config_zero
,
"cpu_offload_params"
):
if
is_False
(
config_zero
,
"cpu_offload"
)
and
is_False
(
config_zero
,
"cpu_offload_params"
):
return
False
elif
stage
==
3
:
offload_devices
=
[
"cpu"
,
"nvme"
]
...
...
@@ -289,14 +291,13 @@ def get_all_configs(tuning_space: dict, ignore_keys=None):
Args:
tuning_space (dict): the tuning space where tunable parameters are lists of values.
"""
def
gen_combinations
(
d
:
dict
):
keys
,
values
=
d
.
keys
(),
d
.
values
()
for
v
in
values
:
if
not
isinstance
(
v
,
list
):
v
=
[
v
]
values_choices
=
(
gen_combinations
(
v
)
if
isinstance
(
v
,
dict
)
else
get_list
(
v
)
for
v
in
values
)
values_choices
=
(
gen_combinations
(
v
)
if
isinstance
(
v
,
dict
)
else
get_list
(
v
)
for
v
in
values
)
for
comb
in
itertools
.
product
(
*
values_choices
):
yield
dict
(
zip
(
keys
,
comb
))
...
...
deepspeed/checkpoint/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.reshape_meg_2d
import
reshape_meg_2d_parallel
from
.deepspeed_checkpoint
import
DeepSpeedCheckpoint
from
.utils
import
(
get_layer_ckpt_name_for_rank
,
get_model_ckpt_name_for_rank
,
get_zero_ckpt_name_for_rank
)
from
.utils
import
(
get_layer_ckpt_name_for_rank
,
get_model_ckpt_name_for_rank
,
get_zero_ckpt_name_for_rank
)
from
.reshape_utils
import
(
merge_state
)
...
...
deepspeed/checkpoint/constants.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
'''
Various symbolic constants used for model checkpointing
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
Various symbolic constants used for model checkpointing
"""
#########################################
# Optimizer checkpoint keys
...
...
@@ -24,6 +27,8 @@ FP32_WEIGHT_KEY = "fp32"
PARAM
=
'param'
PARAM_SHAPES
=
'param_shapes'
BUFFER_NAMES
=
'buffer_names'
FROZEN_PARAM_SHAPES
=
'frozen_param_shapes'
FROZEN_PARAM_FRAGMENTS
=
'frozen_param_fragments'
#########################################
# Checkpoint naming constants
...
...
deepspeed/checkpoint/deepspeed_checkpoint.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
os
from
typing
import
Dict
import
torch
from
.reshape_3d_utils
import
model_3d_desc
from
.reshape_utils
import
(
basic_folder_validation
,
merge_state
,
partition_data
,
get_files
,
get_files_with_prefix
)
from
.reshape_utils
import
(
basic_folder_validation
,
merge_state
,
partition_data
,
get_files
,
get_files_with_prefix
)
from
.constants
import
(
MODEL_FILE_PREFIX
,
LAYER_FILE_PREFIX
)
...
...
@@ -24,19 +23,15 @@ CHECKPOINT_INFO_KEY = 'checkpoint_info'
ITERATION_KEY
=
'iteration'
SEQUENTIAL_LAYERS
=
[
'input_layernorm.weight'
,
'input_layernorm.bias'
,
'self_attention.dense.bias'
,
'post_attention_layernorm.weight'
,
'post_attention_layernorm.bias'
,
'mlp.dense_4h_to_h.bias'
,
'position_embeddings.weight'
'input_layernorm.weight'
,
'input_layernorm.bias'
,
'self_attention.dense.bias'
,
'post_attention_layernorm.weight'
,
'post_attention_layernorm.bias'
,
'mlp.dense_4h_to_h.bias'
,
'position_embeddings.weight'
]
LAYER_CONCAT_DIM
=
{
'self_attention.dense.weight'
:
1
,
'mlp.dense_4h_to_h.weight'
:
1
}
class
DeepSpeedCheckpoint
(
object
):
def
__init__
(
self
,
dir
,
tp_degree
=
None
,
pp_degree
=
None
,
dp_degree
=
None
):
self
.
dir
=
dir
self
.
_validate_folder
(
dir
)
...
...
@@ -50,33 +45,24 @@ class DeepSpeedCheckpoint(object):
self
.
layer_keys
=
self
.
_get_layer_keys
()
self
.
layer_count
=
len
(
self
.
layer_keys
)
self
.
tp_degree
=
self
.
zero_checkpoint
.
get_src_tp_degree
(
)
if
tp_degree
is
None
else
tp_degree
self
.
pp_degree
=
self
.
zero_checkpoint
.
get_src_pp_degree
(
)
if
pp_degree
is
None
else
pp_degree
self
.
dp_degree
=
self
.
zero_checkpoint
.
get_src_dp_degree
(
)
if
dp_degree
is
None
else
dp_degree
self
.
tp_degree
=
self
.
zero_checkpoint
.
get_src_tp_degree
()
if
tp_degree
is
None
else
tp_degree
self
.
pp_degree
=
self
.
zero_checkpoint
.
get_src_pp_degree
()
if
pp_degree
is
None
else
pp_degree
self
.
dp_degree
=
self
.
zero_checkpoint
.
get_src_dp_degree
()
if
dp_degree
is
None
else
dp_degree
self
.
original_world_size
=
self
.
zero_checkpoint
.
get_src_tp_degree
(
)
*
self
.
zero_checkpoint
.
get_src_pp_degree
(
self
.
original_world_size
=
self
.
zero_checkpoint
.
get_src_tp_degree
()
*
self
.
zero_checkpoint
.
get_src_pp_degree
(
)
*
self
.
zero_checkpoint
.
get_src_dp_degree
()
self
.
world_size
=
self
.
tp_degree
*
self
.
pp_degree
*
self
.
dp_degree
self
.
old_2d_map
=
meg_2d_parallel_map
(
self
.
zero_checkpoint
.
get_src_pp_degree
(),
self
.
zero_checkpoint
.
get_src_tp_degree
())
self
.
old_2d_map
.
simple_init
()
self
.
new_2d_map
=
reshape_meg_2d_parallel
(
old_pp_degree
=
self
.
zero_checkpoint
.
get_src_pp_degree
(),
old_tp_degree
=
self
.
zero_checkpoint
.
get_src_tp_degree
(),
new_pp_degree
=
self
.
pp_degree
,
new_tp_degree
=
self
.
tp_degree
)
if
self
.
is_change_pp_degree
()
or
self
.
is_change_tp_degree
(
)
or
self
.
is_change_dp_degree
():
self
.
zero_checkpoint
.
reshape
(
model_3d_desc
(
self
.
pp_degree
,
self
.
tp_degree
,
self
.
dp_degree
))
self
.
new_2d_map
=
reshape_meg_2d_parallel
(
old_pp_degree
=
self
.
zero_checkpoint
.
get_src_pp_degree
(),
old_tp_degree
=
self
.
zero_checkpoint
.
get_src_tp_degree
(),
new_pp_degree
=
self
.
pp_degree
,
new_tp_degree
=
self
.
tp_degree
)
if
self
.
is_change_pp_degree
()
or
self
.
is_change_tp_degree
()
or
self
.
is_change_dp_degree
():
self
.
zero_checkpoint
.
reshape
(
model_3d_desc
(
self
.
pp_degree
,
self
.
tp_degree
,
self
.
dp_degree
))
self
.
global_state
=
{}
...
...
@@ -84,8 +70,7 @@ class DeepSpeedCheckpoint(object):
self
.
pp_to_transformer_map
=
self
.
_build_pp_transformer_map
()
self
.
transformer_file_map
=
self
.
_build_transformer_file_map
()
self
.
tp_to_embedding_map
=
self
.
_build_tp_other_layer_map
(
EMBEDDING_LAYER_INDEX
)
self
.
tp_to_final_norm_map
=
self
.
_build_tp_other_layer_map
(
FINAL_LAYER_NORM_INDEX
)
self
.
tp_to_final_norm_map
=
self
.
_build_tp_other_layer_map
(
FINAL_LAYER_NORM_INDEX
)
self
.
_build_global_state
()
def
is_change_tp_degree
(
self
):
...
...
@@ -131,9 +116,7 @@ class DeepSpeedCheckpoint(object):
keys_to_ignore
=
[
PARAM_SHAPES
])
def
get_zero_files
(
self
,
pp_index
,
tp_index
,
dp_index
)
->
list
:
return
self
.
zero_checkpoint
.
get_files_for_rank
(
pp_index
=
pp_index
,
tp_index
=
tp_index
,
dp_index
=
dp_index
)
return
self
.
zero_checkpoint
.
get_files_for_rank
(
pp_index
=
pp_index
,
tp_index
=
tp_index
,
dp_index
=
dp_index
)
def
get_embedding_layer_id
(
self
):
return
self
.
layer_keys
[
EMBEDDING_LAYER_INDEX
]
...
...
@@ -150,11 +133,7 @@ class DeepSpeedCheckpoint(object):
def
get_embedding_state
(
self
,
tp_index
:
int
)
->
Dict
:
assert
tp_index
in
self
.
tp_to_embedding_map
.
keys
()
sd_list
=
[
torch
.
load
(
fname
,
map_location
=
torch
.
device
(
'cpu'
))
for
fname
in
self
.
tp_to_embedding_map
[
tp_index
]
]
sd_list
=
[
torch
.
load
(
fname
,
map_location
=
torch
.
device
(
'cpu'
))
for
fname
in
self
.
tp_to_embedding_map
[
tp_index
]]
sd
=
self
.
_merge_state_dicts
(
sd_list
)
return
sd
...
...
@@ -179,10 +158,7 @@ class DeepSpeedCheckpoint(object):
assert
tp_index
<
self
.
tp_degree
assert
pp_index
<
self
.
pp_degree
fname_list
=
self
.
get_2d_parallel_files
(
tp_index
=
tp_index
,
pp_index
=
pp_index
)
sd_list
=
[
torch
.
load
(
fname
,
map_location
=
torch
.
device
(
'cpu'
))
for
fname
in
fname_list
]
sd_list
=
[
torch
.
load
(
fname
,
map_location
=
torch
.
device
(
'cpu'
))
for
fname
in
fname_list
]
merged_sd
=
None
for
sd
in
sd_list
:
...
...
@@ -198,10 +174,7 @@ class DeepSpeedCheckpoint(object):
assert
pp_index
<
self
.
pp_degree
t_list
=
[]
for
fname_list
in
self
.
transformer_file_map
[(
tp_index
,
pp_index
)]:
sd_list
=
[
torch
.
load
(
fname
,
map_location
=
torch
.
device
(
'cpu'
))
for
fname
in
fname_list
]
sd_list
=
[
torch
.
load
(
fname
,
map_location
=
torch
.
device
(
'cpu'
))
for
fname
in
fname_list
]
sd
=
self
.
_merge_state_dicts
(
sd_list
)
t_list
.
append
(
sd
)
return
t_list
...
...
@@ -212,8 +185,7 @@ class DeepSpeedCheckpoint(object):
def
get_final_norm_state
(
self
,
tp_index
:
int
)
->
Dict
:
assert
tp_index
in
self
.
tp_to_final_norm_map
.
keys
()
sd
=
torch
.
load
(
self
.
tp_to_final_norm_map
[
tp_index
][
0
],
map_location
=
torch
.
device
(
'cpu'
))
sd
=
torch
.
load
(
self
.
tp_to_final_norm_map
[
tp_index
][
0
],
map_location
=
torch
.
device
(
'cpu'
))
return
sd
def
get_final_norm_files
(
self
,
tp_index
:
int
)
->
list
:
...
...
@@ -222,8 +194,7 @@ class DeepSpeedCheckpoint(object):
def
_build_tp_other_layer_map
(
self
,
layer_index
:
int
):
assert
layer_index
<
len
(
self
.
layer_files
)
layer_files
=
get_files_with_prefix
(
self
.
layer_files
,
self
.
layer_keys
[
layer_index
])
layer_files
=
get_files_with_prefix
(
self
.
layer_files
,
self
.
layer_keys
[
layer_index
])
layer_file_partitions
=
partition_data
(
layer_files
,
self
.
tp_degree
)
data_map
=
{
i
:
flist
for
i
,
flist
in
enumerate
(
layer_file_partitions
)}
return
data_map
...
...
@@ -238,11 +209,7 @@ class DeepSpeedCheckpoint(object):
data_map
=
{}
transformer_layers
=
self
.
layer_keys
[
1
:
-
1
]
layers_per_pp
=
len
(
transformer_layers
)
//
self
.
pp_degree
data_map
=
{
i
:
transformer_layers
[
i
*
layers_per_pp
:(
i
+
1
)
*
layers_per_pp
]
for
i
in
range
(
0
,
self
.
pp_degree
)
}
data_map
=
{
i
:
transformer_layers
[
i
*
layers_per_pp
:(
i
+
1
)
*
layers_per_pp
]
for
i
in
range
(
0
,
self
.
pp_degree
)}
return
data_map
def
_dump_mapping
(
self
,
data_map
,
map_tag
=
None
):
...
...
@@ -308,10 +275,8 @@ class DeepSpeedCheckpoint(object):
file_list
=
get_files
(
dir
)
for
file_prefix
in
[
MODEL_FILE_PREFIX
,
LAYER_FILE_PREFIX
,
f
'
{
LAYER_FILE_PREFIX
}
01'
]:
for
file_prefix
in
[
MODEL_FILE_PREFIX
,
LAYER_FILE_PREFIX
,
f
'
{
LAYER_FILE_PREFIX
}
01'
]:
ckpt_files
=
get_files_with_prefix
(
file_list
,
file_prefix
)
assert
len
(
ckpt_files
)
>
0
,
f
'
{
dir
}
seems a bogus DeepSpeed checkpoint folder: Cannot find
{
file_prefix
}
* files in there.'
assert
len
(
ckpt_files
)
>
0
,
f
'
{
dir
}
seems a bogus DeepSpeed checkpoint folder: Cannot find
{
file_prefix
}
* files in there.'
deepspeed/checkpoint/reshape_3d_utils.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
from
.reshape_utils
import
(
get_files
,
get_files_with_prefix
,
partition_data
,
get_zero_files
)
# DeepSpeed Team
from
.reshape_utils
import
(
get_files
,
get_files_with_prefix
,
partition_data
,
get_zero_files
)
from
.constants
import
(
MODEL_FILE_PREFIX
,
LAYER_FILE_PREFIX
)
...
...
@@ -15,6 +15,7 @@ DP_DIM = 'DP'
class
model_3d_desc
(
object
):
def
__init__
(
self
,
pp_degree
=
1
,
tp_degree
=
1
,
dp_degree
=
1
):
self
.
pp_degree
=
pp_degree
self
.
tp_degree
=
tp_degree
...
...
@@ -33,8 +34,7 @@ class model_3d_desc(object):
src_2d_size
=
self
.
pp_degree
*
self
.
tp_degree
,
dp_degree
=
self
.
dp_degree
)
return
unflatten_dp_dimension
(
meg_2d_map
=
flat_3d_map
,
dp_degree
=
target_3d_desc
.
dp_degree
)
return
unflatten_dp_dimension
(
meg_2d_map
=
flat_3d_map
,
dp_degree
=
target_3d_desc
.
dp_degree
)
def
get_desc
(
self
):
return
f
'
{
PP_DIM
}
,
{
TP_DIM
}
,
{
DP_DIM
}
= (
{
self
.
pp_degree
}
,
{
self
.
tp_degree
}
,
{
self
.
dp_degree
}
)'
...
...
@@ -45,14 +45,11 @@ class model_3d_desc(object):
def
is_valid
(
self
,
pp_index
,
tp_index
,
dp_index
):
err_msg
=
[]
valid
=
True
for
index
,
degree
,
dim_name
in
[
(
pp_index
,
self
.
pp_degree
,
PP_DIM
),
(
tp_index
,
self
.
tp_degree
,
TP_DIM
),
(
dp_index
,
self
.
dp_degree
,
DP_DIM
)]:
for
index
,
degree
,
dim_name
in
[(
pp_index
,
self
.
pp_degree
,
PP_DIM
),
(
tp_index
,
self
.
tp_degree
,
TP_DIM
),
(
dp_index
,
self
.
dp_degree
,
DP_DIM
)]:
if
index
>=
degree
:
valid
=
False
err_msg
.
append
(
f
'
{
dim_name
}
indexing error: index
{
index
}
>= degree
{
degree
}
'
)
err_msg
.
append
(
f
'
{
dim_name
}
indexing error: index
{
index
}
>= degree
{
degree
}
'
)
return
valid
,
err_msg
...
...
@@ -60,18 +57,15 @@ class model_3d_desc(object):
err_msg
=
[]
if
target_3d_desc
.
pp_degree
>
self
.
pp_degree
:
err_msg
.
append
(
f
'Expansion reshape not supported -
{
PP_DIM
}
:
{
self
.
pp_degree
}
--->
{
target_3d_desc
.
pp_degree
}
'
)
f
'Expansion reshape not supported -
{
PP_DIM
}
:
{
self
.
pp_degree
}
--->
{
target_3d_desc
.
pp_degree
}
'
)
if
target_3d_desc
.
tp_degree
>
self
.
tp_degree
:
err_msg
.
append
(
f
'Expansion reshape not supported -
{
TP_DIM
}
:
{
self
.
tp_degree
}
--->
{
target_3d_desc
.
tp_degree
}
'
)
f
'Expansion reshape not supported -
{
TP_DIM
}
:
{
self
.
tp_degree
}
--->
{
target_3d_desc
.
tp_degree
}
'
)
if
target_3d_desc
.
dp_degree
>
self
.
dp_degree
:
err_msg
.
append
(
f
'Expansion reshape not supported -
{
DP_DIM
}
:
{
self
.
dp_degree
}
--->
{
target_3d_desc
.
dp_degree
}
'
)
f
'Expansion reshape not supported -
{
DP_DIM
}
:
{
self
.
dp_degree
}
--->
{
target_3d_desc
.
dp_degree
}
'
)
return
len
(
err_msg
)
==
0
,
err_msg
...
...
@@ -106,10 +100,7 @@ def flatten_dp_dimension(meg_2d_map, src_2d_size, dp_degree):
def
unflatten_dp_dimension
(
meg_2d_map
,
dp_degree
):
pp_degree
=
meg_2d_map
.
pp_degree
tp_degree
=
meg_2d_map
.
tp_degree
meg_2d_map_list
=
[
meg_2d_parallel_map
(
pp_degree
=
pp_degree
,
tp_degree
=
tp_degree
)
for
_
in
range
(
dp_degree
)
]
meg_2d_map_list
=
[
meg_2d_parallel_map
(
pp_degree
=
pp_degree
,
tp_degree
=
tp_degree
)
for
_
in
range
(
dp_degree
)]
for
pp_index
in
range
(
pp_degree
):
for
tp_index
in
range
(
tp_degree
):
flat_dp_indices
=
meg_2d_map
.
get_data
(
pp_index
,
tp_index
)
...
...
deepspeed/checkpoint/reshape_meg_2d.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.reshape_utils
import
partition_data
class
meg_2d_parallel_map
(
object
):
def
__init__
(
self
,
pp_degree
,
tp_degree
):
self
.
pp_degree
=
pp_degree
self
.
tp_degree
=
tp_degree
...
...
@@ -11,8 +15,7 @@ class meg_2d_parallel_map(object):
def
simple_init
(
self
):
self
.
map
=
{
self
.
_make_key
(
i
//
self
.
tp_degree
,
i
%
self
.
tp_degree
):
[
i
]
self
.
_make_key
(
i
//
self
.
tp_degree
,
i
%
self
.
tp_degree
):
[
i
]
for
i
in
range
(
self
.
pp_degree
*
self
.
tp_degree
)
}
...
...
@@ -74,11 +77,7 @@ def _reshape_pp_dimension(old_2d_map, new_pp_degree):
return
new_2d_map
def
reshape_meg_2d_parallel
(
old_pp_degree
,
old_tp_degree
,
new_pp_degree
,
new_tp_degree
,
verbose
=
False
):
def
reshape_meg_2d_parallel
(
old_pp_degree
,
old_tp_degree
,
new_pp_degree
,
new_tp_degree
,
verbose
=
False
):
assert
new_pp_degree
<=
old_pp_degree
assert
new_tp_degree
<=
old_tp_degree
...
...
@@ -137,8 +136,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
tensor_model_parallel_size
=
min
(
tp_size
,
world_size
)
pipeline_model_parallel_size
=
min
(
pp_size
,
world_size
)
data_parallel_size
=
world_size
//
(
tensor_model_parallel_size
*
pipeline_model_parallel_size
)
data_parallel_size
=
world_size
//
(
tensor_model_parallel_size
*
pipeline_model_parallel_size
)
num_tensor_model_parallel_groups
=
world_size
//
tensor_model_parallel_size
num_pipeline_model_parallel_groups
=
world_size
//
pipeline_model_parallel_size
...
...
@@ -158,10 +156,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
# Build the model-parallel groups.
all_pp_group_ranks
=
[]
for
i
in
range
(
data_parallel_size
):
ranks
=
[
data_parallel_group_ranks
[
i
]
for
data_parallel_group_ranks
in
all_dp_group_ranks
]
ranks
=
[
data_parallel_group_ranks
[
i
]
for
data_parallel_group_ranks
in
all_dp_group_ranks
]
all_pp_group_ranks
.
append
(
list
(
ranks
))
print
(
f
"PP"
,
all_pp_group_ranks
)
...
...
@@ -169,8 +164,7 @@ def get_mpu_ranks(tp_size=1, pp_size=1, dp_size=1, virtual_pp_size=None):
# Build the tensor model-parallel groups.
all_tp_group_ranks
=
[]
for
i
in
range
(
num_tensor_model_parallel_groups
):
ranks
=
range
(
i
*
tensor_model_parallel_size
,
(
i
+
1
)
*
tensor_model_parallel_size
)
ranks
=
range
(
i
*
tensor_model_parallel_size
,
(
i
+
1
)
*
tensor_model_parallel_size
)
all_tp_group_ranks
.
append
(
list
(
ranks
))
print
(
f
"TP"
,
all_tp_group_ranks
)
...
...
deepspeed/checkpoint/reshape_utils.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
os
import
torch
...
...
@@ -49,11 +52,7 @@ def partition_data(data_list, num_partitions):
num_elems
=
len
(
data_list
)
assert
num_elems
%
num_partitions
==
0
partition_size
=
num_elems
//
num_partitions
partitions_list
=
[
data_list
[
i
:
i
+
partition_size
]
for
i
in
range
(
0
,
num_elems
,
partition_size
)
]
partitions_list
=
[
data_list
[
i
:
i
+
partition_size
]
for
i
in
range
(
0
,
num_elems
,
partition_size
)]
return
partitions_list
...
...
@@ -76,9 +75,7 @@ def merge_state_dict(dict_a, dict_b, key_list):
def
merge_state_list
(
list_a
,
list_b
,
key_list
):
if
len
(
list_a
)
!=
len
(
list_b
):
print
(
f
'
{
_key_list_to_string
(
key_list
)
}
'
)
raise
ValueError
(
f
'Cannot merge lists of different lengths, a =
{
len
(
list_a
)
}
b =
{
len
(
list_b
)
}
'
)
raise
ValueError
(
f
'Cannot merge lists of different lengths, a =
{
len
(
list_a
)
}
b =
{
len
(
list_b
)
}
'
)
return
[
merge_state
(
a
,
b
,
key_list
)
for
a
,
b
in
zip
(
list_a
,
list_b
)]
...
...
@@ -87,8 +84,7 @@ def merge_state(state_a, state_b, key_list=[]):
if
type
(
state_a
)
!=
type
(
state_b
):
key_list_string
=
_key_list_to_string
(
key_list
)
print
(
f
'key_list =
{
key_list_string
}
'
)
raise
ValueError
(
f
'Cannot merge two states of types
{
type
(
state_a
)
}
and type
{
type
(
state_b
)
}
'
)
raise
ValueError
(
f
'Cannot merge two states of types
{
type
(
state_a
)
}
and type
{
type
(
state_b
)
}
'
)
if
type
(
state_a
)
in
(
dict
,
OrderedDict
):
return
merge_state_dict
(
state_a
,
state_b
,
key_list
)
...
...
deepspeed/checkpoint/universal_checkpoint.py
View file @
5bcc463d
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
os
import
torch
import
types
from
.constants
import
(
FP32_WEIGHT_KEY
,
PARAM
,
VOCAB_DIVISIBILITY_PADDING_TENSOR
,
CAT_DIM
)
from
.constants
import
(
FP32_WEIGHT_KEY
,
PARAM
,
VOCAB_DIVISIBILITY_PADDING_TENSOR
,
CAT_DIM
)
def
load_hp_checkpoint_state
(
self
,
folder
,
tp_rank
,
tp_world_size
):
...
...
@@ -44,9 +43,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
# the converter to universal currently strips the original padding completely so the saved
# weight is padding-free and we just need to add new padding depending on the target TP
# degree
vocab_divisibility_padding_tensor
=
ckpt_dict
.
get
(
VOCAB_DIVISIBILITY_PADDING_TENSOR
,
None
)
vocab_divisibility_padding_tensor
=
ckpt_dict
.
get
(
VOCAB_DIVISIBILITY_PADDING_TENSOR
,
None
)
if
vocab_divisibility_padding_tensor
is
not
None
:
# In the absence of data passed from the user wrt new padded vocab specific to tp degree
# we can again derive that data by reverse engineering the target shapes like so:
...
...
@@ -56,13 +53,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
padding_size
=
padded_target_vocab_size
-
full_hp_param
.
shape
[
0
]
# Implement the following concat in efficient way using pad
#full_hp_param = torch.cat((full_hp_param, padding_tensor), 0)
full_hp_param
=
torch
.
nn
.
functional
.
pad
(
full_hp_param
,
(
0
,
0
,
0
,
padding_size
),
"constant"
,
0
)
full_hp_param
=
torch
.
nn
.
functional
.
pad
(
full_hp_param
,
(
0
,
0
,
0
,
padding_size
),
"constant"
,
0
)
full_hp_param
[:
-
padding_size
,
:]
=
vocab_divisibility_padding_tensor
else
:
# Need to shrink or keep the same
...
...
@@ -76,8 +67,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
assert
full_param_numel
==
tp_world_size
*
tp_slice_numel
,
\
f
'Loading
{
ckpt_file
}
full param numel
{
full_param_numel
}
!= tensor slice numel
{
tp_slice_numel
}
* tp_world_size
{
tp_world_size
}
'
dst_tensor
=
hp_mapping
.
hp_fragment
if
key
==
FP32_WEIGHT_KEY
else
hp_mapping
.
get_optim_state_fragment
(
key
)
dst_tensor
=
hp_mapping
.
hp_fragment
if
key
==
FP32_WEIGHT_KEY
else
hp_mapping
.
get_optim_state_fragment
(
key
)
# print(f"{full_hp_param.shape=} {full_param_numel=} {folder=}")
# print(f"{dst_tensor.shape=} {dst_tensor.numel()=}{folder=}")
...
...
@@ -90,9 +80,7 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
tp_hp_slice
=
tp_hp_slice
.
flatten
()
lp_frag_address
=
hp_mapping
.
lp_fragment_address
tp_hp_fragment
=
tp_hp_slice
.
narrow
(
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
tp_hp_fragment
=
tp_hp_slice
.
narrow
(
0
,
lp_frag_address
.
start
,
lp_frag_address
.
numel
)
assert
dst_tensor
.
numel
()
==
lp_frag_address
.
numel
,
\
f
'Load checkpoint
{
key
}
dst_tensor numel
{
dst_tensor
.
numel
()
}
!= src numel
{
lp_frag_address
.
numel
}
'
...
...
@@ -104,5 +92,4 @@ def load_hp_checkpoint_state(self, folder, tp_rank, tp_world_size):
def
enable_universal_checkpoint
(
param_list
):
for
param
in
param_list
:
param
.
load_hp_checkpoint_state
=
types
.
MethodType
(
load_hp_checkpoint_state
,
param
)
param
.
load_hp_checkpoint_state
=
types
.
MethodType
(
load_hp_checkpoint_state
,
param
)
deepspeed/checkpoint/utils.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
os
from
.constants
import
(
MODEL_FILE_PREFIX
,
MODEL_FILE_SUFFIX
,
OPTIM_FILE_SUFFIX
,
ZERO_FILE_PREFIX
)
from
.constants
import
(
MODEL_FILE_PREFIX
,
MODEL_FILE_SUFFIX
,
OPTIM_FILE_SUFFIX
,
ZERO_FILE_PREFIX
)
def
get_model_ckpt_name_for_rank
(
base_folder
,
mp_rank_str
):
...
...
deepspeed/checkpoint/zero_checkpoint.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
torch
from
.constants
import
(
BASE_OPTIMIZER_STATE
,
GROUP_PADDINGS
,
OPTIMIZER_STATE_DICT
,
PARTITION_COUNT
)
from
.constants
import
(
BASE_OPTIMIZER_STATE
,
GROUP_PADDINGS
,
OPTIMIZER_STATE_DICT
,
PARTITION_COUNT
)
from
.reshape_utils
import
(
basic_folder_validation
,
get_zero_files
,
merge_state
)
...
...
@@ -15,6 +15,7 @@ GROUP_STATE_KEY = 'state'
class
ZeROCheckpoint
(
object
):
def
__init__
(
self
,
dir
):
basic_folder_validation
(
dir
)
self
.
dir
=
dir
...
...
@@ -49,12 +50,7 @@ class ZeROCheckpoint(object):
file_idx_list
=
self
.
get_file_indices_for_rank
(
pp_index
,
tp_index
,
dp_index
)
return
[
self
.
file_list
[
idx
]
for
idx
in
file_idx_list
]
def
get_state_for_rank
(
self
,
pp_index
,
tp_index
,
dp_index
,
keys_to_ignore
=
[],
strip_tensor_paddings
=
True
):
def
get_state_for_rank
(
self
,
pp_index
,
tp_index
,
dp_index
,
keys_to_ignore
=
[],
strip_tensor_paddings
=
True
):
state_file_list
=
self
.
get_files_for_rank
(
pp_index
,
tp_index
,
dp_index
)
merged_sd
=
None
for
state_file
in
state_file_list
:
...
...
@@ -111,10 +107,7 @@ class ZeROCheckpoint(object):
for
state_name
,
state_value
in
group_state
.
items
():
if
torch
.
is_tensor
(
state_value
):
raw_length
=
state_value
.
numel
()
-
group_paddings
[
key
]
group_state
[
state_name
]
=
torch
.
narrow
(
state_value
,
0
,
0
,
raw_length
).
clone
()
group_state
[
state_name
]
=
torch
.
narrow
(
state_value
,
0
,
0
,
raw_length
).
clone
()
def
_clear_group_paddings
(
self
,
sd
):
group_paddings
=
self
.
_get_optimizer_state
(
sd
,
GROUP_PADDINGS
)
...
...
@@ -144,5 +137,4 @@ class ZeROCheckpoint(object):
partition_counts
=
self
.
_get_optimizer_state
(
sd
,
PARTITION_COUNT
)
if
partition_counts
:
num_groups
=
len
(
partition_counts
)
sd
[
OPTIMIZER_STATE_DICT
][
PARTITION_COUNT
]
=
[
self
.
target_3d
.
dp_degree
]
*
num_groups
sd
[
OPTIMIZER_STATE_DICT
][
PARTITION_COUNT
]
=
[
self
.
target_3d
.
dp_degree
]
*
num_groups
deepspeed/comm/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
import
torch
from
.utils
import
*
from
deepspeed
import
utils
supported_torch_version
=
False
# See more details at: https://github.com/pytorch/pytorch/pull/48767
# The PG API in torch versions lesser than 1.8 are different so it is
# non-trivial to support both in the same API. We will just use the
# DS comm. backend in deepspeed/comm/comm.py if torch version if 1.8+.
if
older_torch
():
# Add custom deepspeed torch comm functions here since we can't import deepspeed.comm
# NOTE: We can't call torch.distributed directly here. Current hack is to import functions before calling them.
supported_torch_version
=
False
from
torch.distributed
import
*
def
get_world_group
():
return
group
.
WORLD
# DeepSpeed Team
def
get_global_rank
(
group
,
group_rank
):
if
hasattr
(
torch
.
distributed
.
distributed_c10d
,
"get_global_rank"
):
from
torch.distributed.distributed_c10d
import
get_global_rank
as
_get_global_rank
else
:
from
torch.distributed.distributed_c10d
import
_get_global_rank
return
_get_global_rank
(
group
,
group_rank
)
def
allgather_fn
(
output_tensor
,
input_tensor
,
group
=
None
,
async_op
=
False
):
from
torch.distributed
import
all_gather
,
get_world_size
from
torch
import
chunk
output_tensors
=
list
(
chunk
(
output_tensor
,
get_world_size
(
group
)))
return
all_gather
(
output_tensors
,
input_tensor
,
group
=
group
,
async_op
=
async_op
)
def
reduce_scatter_fn
(
output_tensor
,
input_tensor
,
group
=
None
,
async_op
=
False
):
from
torch.distributed
import
reduce_scatter
,
get_world_size
from
torch
import
chunk
input_tensor_lst
=
list
(
chunk
(
input_tensor
,
get_world_size
(
group
)))
return
reduce_scatter
(
output_tensor
,
input_tensor_lst
,
group
=
group
)
def
configure
(
deepspeed_config
=
None
,
enabled
=
None
,
prof_all
=
None
,
prof_ops
=
None
,
verbose
=
None
):
utils
.
logger
.
warn
(
"Communication logging is not supported in torch versions older than 1.8"
)
else
:
supported_torch_version
=
True
from
.comm
import
*
from
.utils
import
*
from
.comm
import
*
deepspeed/comm/backend.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
''' DeepSpeed Communication Backend.
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
DeepSpeed Communication Backend.
In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distributed. Simply wrap torch.distributed for now.
...
...
@@ -8,18 +12,18 @@ In the future, directly use NCCL/MPI/Gloo/etc without requiring torch.distribute
- MPI -- [EXPERIMENTAL]
- RCCL -- [EXPERIMENTAL]
- GLOO -- [EXPERIMENTAL]
# DS backend wrapper for torch.distributed [DEFAULT]
- T-NCCL -- [DEFAULT]
- T-GLOO
- T-MPI
'''
'''
Backend is the base class
#
Backend is the base class
-- NcclBackend, MpiBackend, and TorchBackend are the main subclasses. TorchBackend is the only officially supported backend for now.
'''
"""
class
Backend
(
object
):
def
__init__
(
self
,
name
=
'backend'
,
rank
=
0
,
size
=
1
):
self
.
name
=
name
# The world size and rank of the world process group
...
...
deepspeed/comm/comm.py
View file @
5bcc463d
"""
Copyright 2021 The Microsoft DeepSpeed Team
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
"""
DeepSpeed Communication Package: deepspeed.comm
deepspeed.comm
-- import and use deepspeeed.ops.comm
-- use torch.distributed directly if both this package and torch.distributed use the same NCCL version
-- use custom collectives
-- can either use torch.dist or ds.ops.comm?
Note: the old 1-bit compressed allreduce variants that resided in deepspeed.runtime.comm will be moved here as well.
deepspeed.comm API
-- must be kept fully compatible (same signatures) as torch.dist API to ensure backward/cross-framework compatibility.
-- e.g. if a client code used
from deepspeed import comm as dist
instead of
import torch.distributed as dist
The code should work without breaking any of the public torch.distributed functionality
Future:
-- deepspeed groups API should be brought into ds.comm
"""
...
...
@@ -63,10 +59,6 @@ timer_summary = {}
comms_logger
=
CommsLogger
()
# Ensure we don't warn about base collectives more than once
has_warned_all_gather
=
False
has_warned_reduce_scatter
=
False
# Maintain objects of all initialized ds backends and assign them using the API functions in this file
nccl_backend
=
None
mpi_backend
=
None
...
...
@@ -110,12 +102,13 @@ def configure(
# Logging wrapper for timing ops
def
timed_op
(
func
):
def
log_wrapper
(
*
args
,
**
kwargs
):
# Add enabled flag so that overhead to each comm op is two if conditions at most
if
comms_logger
.
enabled
:
if
(
'prof'
in
kwargs
and
kwargs
[
'prof'
])
or
comms_logger
.
prof_all
or
(
'log_name'
in
kwargs
and
kwargs
[
'log_name'
]
in
comms_logger
.
prof_ops
):
if
(
'prof'
in
kwargs
and
kwargs
[
'prof'
])
or
comms_logger
.
prof_all
or
(
'log_name'
in
kwargs
and
kwargs
[
'log_name'
]
in
comms_logger
.
prof_ops
):
# Need func args for their defaults
func_args
=
get_default_args
(
func
)
func_args
.
update
(
kwargs
)
...
...
@@ -133,8 +126,7 @@ def timed_op(func):
if
cdb
.
using_mpi
:
cdb
.
barrier
()
if
(
'prof'
in
kwargs
and
kwargs
[
'prof'
])
or
comms_logger
.
prof_all
or
(
'log_name'
in
kwargs
and
kwargs
[
'log_name'
]
in
comms_logger
.
prof_ops
):
'log_name'
in
kwargs
and
kwargs
[
'log_name'
]
in
comms_logger
.
prof_ops
):
log_name
=
get_debug_log_name
(
func_args
,
comms_logger
.
debug
)
raw_name
=
func
.
__name__
timers
(
log_name
).
stop
()
...
...
@@ -182,7 +174,8 @@ def destroy_process_group(group=None):
def new_group(ranks):
    """Create a new distributed group containing the given *ranks*.

    Mirrors torch.distributed.new_group; requires that the DeepSpeed
    communication backend (cdb) has already been initialized.
    """
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.new_group(ranks)
...
...
@@ -196,14 +189,12 @@ def is_available() -> bool:
return
True
def
set_backend
(
backend
):
def
set_backend
(
backend
_name
):
if
not
use_ds_backend
:
utils
.
logger
.
error
(
"DeepSpeed communication backend is required. Please use deepspeed.comm.init_distributed(backend, use_deepspeed=True) to use this functionality"
)
raise
RuntimeError
(
'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
)
raise
RuntimeError
(
'Error: Custom DeepSpeed backend called without initializing DeepSpeed distributed.'
)
global
cdb
global
nccl_backend
...
...
@@ -221,13 +212,7 @@ def set_backend(backend):
@timed_op
def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='broadcast', debug=get_caller_func()):
    """Broadcast *tensor* from rank *src* to every rank in *group*.

    prof/log_name/debug are consumed by the @timed_op logging wrapper; the
    collective itself is delegated to the active backend (cdb).
    """
    # NOTE: the debug default is evaluated once at import time.
    global cdb
    return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
...
...
@@ -241,17 +226,14 @@ def all_gather(tensor_list,
log_name
=
'all_gather'
,
debug
=
get_caller_func
()):
global
cdb
return
cdb
.
all_gather
(
tensor_list
=
tensor_list
,
tensor
=
tensor
,
group
=
group
,
async_op
=
async_op
)
return
cdb
.
all_gather
(
tensor_list
=
tensor_list
,
tensor
=
tensor
,
group
=
group
,
async_op
=
async_op
)
def has_reduce_scatter_tensor():
    """Return True when the active backend exposes a flat reduce-scatter primitive."""
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.has_reduce_scatter_tensor()
def
reduce_scatter_fn
(
output_tensor
,
...
...
@@ -262,23 +244,21 @@ def reduce_scatter_fn(output_tensor,
prof
=
False
,
debug
=
get_caller_func
()):
global
cdb
global
has_warned_reduce_scatter
assert
cdb
is
not
None
and
cdb
.
is_initialized
(
),
'DeepSpeed backend not set, please initialize it using init_process_group()'
if
cdb
.
has_reduce_scatter_
base
:
return
reduce_scatter_
base
(
output_tensor
,
tensor
,
op
=
op
,
group
=
group
,
async_op
=
async_op
,
prof
=
prof
,
debug
=
debug
)
assert
cdb
is
not
None
and
cdb
.
is_initialized
(
),
'DeepSpeed backend not set, please initialize it using init_process_group()'
if
cdb
.
has_reduce_scatter_
tensor
()
:
return
reduce_scatter_
tensor
(
output_tensor
,
tensor
,
op
=
op
,
group
=
group
,
async_op
=
async_op
,
prof
=
prof
,
debug
=
debug
)
else
:
if
not
has_warned_reduce_scatter
:
utils
.
logger
.
warning
(
"unable to find torch.distributed._reduce_scatter_base. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation."
)
has_warned_reduce_scatter
=
True
if
get_rank
()
==
0
:
utils
.
logger
.
warning_once
(
"unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
"torch.distributed.all_gather which will result in suboptimal performance. "
"please consider upgrading your pytorch installation."
)
input_tensor_lst
=
list
(
torch
.
chunk
(
tensor
,
cdb
.
get_world_size
(
group
)))
return
reduce_scatter
(
output_tensor
,
input_tensor_lst
,
...
...
@@ -290,71 +270,54 @@ def reduce_scatter_fn(output_tensor,
@timed_op
def reduce_scatter_tensor(output_tensor,
                          tensor,
                          op=ReduceOp.SUM,
                          group=None,
                          async_op=False,
                          prof=False,
                          log_name='reduce_scatter_tensor',
                          debug=get_caller_func()):
    """Reduce *tensor* across *group* with *op*, scattering the result into *output_tensor*."""
    global cdb
    return cdb.reduce_scatter_tensor(output_tensor=output_tensor,
                                     input_tensor=tensor,
                                     op=op,
                                     group=group,
                                     async_op=async_op)
@
timed_op
def
all_gather_base
(
output_tensor
,
tensor
,
group
=
None
,
async_op
=
False
,
prof
=
False
,
log_name
=
'all_gather_base'
,
debug
=
get_caller_func
()):
global
cdb
return
cdb
.
all_gather_base
(
output_tensor
=
output_tensor
,
input_tensor
=
tensor
,
group
=
group
,
async_op
=
async_op
)
def
has_allgather_base
():
global
cdb
assert
cdb
is
not
None
and
cdb
.
is_initialized
(),
'DeepSpeed backend not set, please initialize it using init_process_group()'
assert
cdb
.
has_allgather_base
is
not
None
,
'has_allgather_base is not yet defined'
return
cdb
.
has_allgather_base
def
allgather_fn
(
output_tensor
,
input_tensor
,
group
=
None
,
async_op
=
False
,
debug
=
get_caller_func
()):
global
cdb
global
has_warned_all_gather
assert
cdb
is
not
None
and
cdb
.
is_initialized
(),
'DeepSpeed backend not set, please initialize it using init_process_group()'
if
cdb
.
has_allgather_base
:
return
all_gather_base
(
output_tensor
,
input_tensor
,
group
=
group
,
async_op
=
async_op
,
debug
=
debug
)
@timed_op
def all_gather_into_tensor(output_tensor,
                           tensor,
                           group=None,
                           async_op=False,
                           prof=False,
                           log_name='all_gather_into_tensor',
                           debug=get_caller_func()):
    """Gather *tensor* from every rank in *group* into the flat *output_tensor*."""
    global cdb
    return cdb.all_gather_into_tensor(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)
def has_all_gather_into_tensor():
    """Return True when the active backend exposes a flat all-gather primitive."""
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.has_all_gather_into_tensor()
def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
    """Flat all-gather of *input_tensor* from all ranks into *output_tensor*.

    Prefers the backend's all_gather_into_tensor primitive; otherwise falls
    back to a chunked torch-style all_gather, with a one-time warning on
    rank 0 about the suboptimal path.
    """
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    if not cdb.has_all_gather_into_tensor():
        if get_rank() == 0:
            utils.logger.warning_once("unable to find torch.distributed.all_gather_into_tensor. will fall back to "
                                      "torch.distributed.all_gather which will result in suboptimal performance. "
                                      "please consider upgrading your pytorch installation.")
        # View the flat output as one chunk per rank for the list-based API.
        output_chunks = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
        return all_gather(output_chunks, input_tensor, group=group, async_op=async_op, debug=debug)
    return all_gather_into_tensor(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
@
timed_op
...
...
@@ -377,49 +340,25 @@ def all_to_all_single(output,
@timed_op
def send(tensor, dst, group=None, tag=0, prof=False, log_name='send', debug=get_caller_func()):
    """Point-to-point send of *tensor* to rank *dst* via the active backend."""
    global cdb
    return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def recv(tensor, src=None, group=None, tag=0, prof=False, log_name='recv', debug=get_caller_func()):
    """Point-to-point receive into *tensor* (from any rank when *src* is None)."""
    global cdb
    return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
@timed_op
def isend(tensor, dst, group=None, tag=0, prof=False, log_name='isend', debug=get_caller_func()):
    """Immediate (non-blocking) send of *tensor* to rank *dst*.

    NOTE(review): this delegates to cdb.send — the same backend call used by
    send(); confirm the backend actually returns an async handle here.
    """
    global cdb
    return cdb.send(tensor=tensor, dst=dst, group=group, tag=tag)
@timed_op
def irecv(tensor, src=None, group=None, tag=0, prof=False, log_name='irecv', debug=get_caller_func()):
    """Immediate (non-blocking) receive into *tensor*.

    NOTE(review): this delegates to cdb.recv — the same backend call used by
    recv(); confirm the backend actually returns an async handle here.
    """
    global cdb
    return cdb.recv(tensor=tensor, src=src, group=group, tag=tag)
...
...
@@ -434,11 +373,7 @@ def gather(tensor,
log_name
=
'gather'
,
debug
=
get_caller_func
()):
global
cdb
return
cdb
.
gather
(
tensor
=
tensor
,
gather_list
=
gather_list
,
dst
=
dst
,
group
=
group
,
async_op
=
async_op
)
return
cdb
.
gather
(
tensor
=
tensor
,
gather_list
=
gather_list
,
dst
=
dst
,
group
=
group
,
async_op
=
async_op
)
@
timed_op
...
...
@@ -451,20 +386,11 @@ def scatter(tensor,
log_name
=
'scatter'
,
debug
=
get_caller_func
()):
global
cdb
return
cdb
.
scatter
(
tensor
=
tensor
,
scatter_list
=
scatter_list
,
src
=
src
,
group
=
group
,
async_op
=
async_op
)
return
cdb
.
scatter
(
tensor
=
tensor
,
scatter_list
=
scatter_list
,
src
=
src
,
group
=
group
,
async_op
=
async_op
)
@timed_op
def barrier(group=None, async_op=False, device_ids=None, prof=False, log_name='barrier', debug=get_caller_func()):
    """Synchronize all ranks in *group* via the active backend."""
    global cdb
    return cdb.barrier(group=group, async_op=async_op, device_ids=device_ids)
...
...
@@ -511,11 +437,31 @@ def reduce_scatter(output,
log_name
=
'reduce_scatter'
,
debug
=
get_caller_func
()):
global
cdb
return
cdb
.
reduce_scatter
(
output
=
output
,
input_list
=
input_list
,
op
=
op
,
group
=
group
,
async_op
=
async_op
)
return
cdb
.
reduce_scatter
(
output
=
output
,
input_list
=
input_list
,
op
=
op
,
group
=
group
,
async_op
=
async_op
)
def has_all_reduce_coalesced():
    """Return True when the active backend supports coalesced all-reduce."""
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    assert cdb.has_all_reduce_coalesced is not None, 'has_all_reduce_coalesced is not yet defined'
    return cdb.has_all_reduce_coalesced
def has_coalescing_manager():
    """Return True when the active backend supports torch's coalescing manager."""
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    assert cdb.has_coalescing_manager is not None, 'has_coalescing_manager is not yet defined'
    return cdb.has_coalescing_manager
def all_gather_coalesced(output_tensors, input_tensors, group=None, async_op=False):
    """All-gather a list of tensors in one coalesced backend call."""
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.all_gather_coalesced(output_tensors, input_tensors, group=group, async_op=async_op)
@
timed_op
...
...
@@ -535,9 +481,22 @@ def all_reduce(tensor,
return
cdb
.
all_reduce
(
tensor
,
op
,
group
,
async_op
)
@timed_op
def all_reduce_coalesced(tensors,
                         op=ReduceOp.SUM,
                         group=None,
                         async_op=False,
                         prof=False,
                         log_name='all_reduce',
                         debug=get_caller_func()):
    """All-reduce a list of tensors in a single coalesced backend call.

    Fix: the original declared ``global cbd`` (a typo) while the body reads
    the module-level backend handle ``cdb``; declare the correct name.
    """
    global cdb
    return cdb.all_reduce_coalesced(tensors, op, group, async_op)
def get_world_group():
    """Return the backend's default (world) process group."""
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_world_group()
...
...
@@ -553,7 +512,8 @@ def get_world_size(group=None) -> int:
"""
global
cdb
assert
cdb
is
not
None
and
cdb
.
is_initialized
(),
'DeepSpeed backend not set, please initialize it using init_process_group()'
assert
cdb
is
not
None
and
cdb
.
is_initialized
(
),
'DeepSpeed backend not set, please initialize it using init_process_group()'
return
cdb
.
get_world_size
(
group
)
...
...
@@ -572,7 +532,8 @@ def get_rank(group=None):
-1, if not part of the group
"""
global
cdb
assert
cdb
is
not
None
and
cdb
.
is_initialized
(),
'DeepSpeed backend not set, please initialize it using init_process_group()'
assert
cdb
is
not
None
and
cdb
.
is_initialized
(
),
'DeepSpeed backend not set, please initialize it using init_process_group()'
return
cdb
.
get_rank
(
group
)
...
...
@@ -585,13 +546,15 @@ def get_local_rank():
local rank (= GPU device ID)
"""
global
cdb
assert
cdb
is
not
None
and
cdb
.
is_initialized
(),
'DeepSpeed backend not set, please initialize it using init_process_group()'
assert
cdb
is
not
None
and
cdb
.
is_initialized
(
),
'DeepSpeed backend not set, please initialize it using init_process_group()'
return
get_local_rank_from_launcher
()
def get_global_rank(group=None, group_rank=0):
    """Translate *group_rank* within *group* into the global (world) rank."""
    global cdb
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'
    return cdb.get_global_rank(group, group_rank)
...
...
@@ -640,9 +603,7 @@ def init_distributed(dist_backend=None,
required_env
=
[
"RANK"
,
"WORLD_SIZE"
,
"MASTER_ADDR"
,
"MASTER_PORT"
,
"LOCAL_RANK"
]
if
auto_mpi_discovery
and
not
all
(
map
(
lambda
v
:
v
in
os
.
environ
,
required_env
)):
if
verbose
:
utils
.
logger
.
info
(
"Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
)
utils
.
logger
.
info
(
"Not using the DeepSpeed or dist launchers, attempting to detect MPI environment..."
)
if
in_aml
()
and
not
in_dlts
():
patch_aml_env_for_torch_nccl_backend
(
verbose
=
verbose
)
elif
in_aws_sm
():
...
...
@@ -658,9 +619,7 @@ def init_distributed(dist_backend=None,
if
dist_backend
==
None
:
dist_backend
=
get_accelerator
().
communication_backend_name
()
if
int
(
os
.
getenv
(
'RANK'
,
'0'
))
==
0
:
utils
.
logger
.
info
(
'Initializing TorchBackend in DeepSpeed with backend {}'
.
format
(
dist_backend
))
utils
.
logger
.
info
(
'Initializing TorchBackend in DeepSpeed with backend {}'
.
format
(
dist_backend
))
# Create a torch backend object, initialize torch distributed, and assign to cdb
cdb
=
TorchBackend
(
dist_backend
,
timeout
,
init_method
,
rank
,
world_size
)
...
...
@@ -695,16 +654,12 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True)
if
verbose
:
utils
.
logger
.
info
(
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.
format
(
os
.
environ
[
'RANK'
],
os
.
environ
[
'LOCAL_RANK'
],
os
.
environ
[
'WORLD_SIZE'
],
os
.
environ
[
'MASTER_ADDR'
],
os
.
environ
[
'MASTER_PORT'
]))
"Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.
format
(
os
.
environ
[
'RANK'
],
os
.
environ
[
'LOCAL_RANK'
],
os
.
environ
[
'WORLD_SIZE'
],
os
.
environ
[
'MASTER_ADDR'
],
os
.
environ
[
'MASTER_PORT'
]))
if
cdb
is
not
None
and
cdb
.
is_initialized
():
assert
cdb
.
get_rank
()
==
rank
,
"MPI rank {} does not match torch rank {}"
.
format
(
rank
,
cdb
.
get_rank
())
assert
cdb
.
get_rank
()
==
rank
,
"MPI rank {} does not match torch rank {}"
.
format
(
rank
,
cdb
.
get_rank
())
assert
cdb
.
get_world_size
()
==
world_size
,
"MPI world size {} does not match torch world size {}"
.
format
(
world_size
,
cdb
.
get_world_size
())
...
...
@@ -731,8 +686,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
"""
os
.
environ
[
"RANK"
]
=
os
.
environ
[
"OMPI_COMM_WORLD_RANK"
]
os
.
environ
[
"WORLD_SIZE"
]
=
os
.
environ
[
"OMPI_COMM_WORLD_SIZE"
]
single_node
=
int
(
os
.
environ
[
"OMPI_COMM_WORLD_LOCAL_SIZE"
])
==
int
(
os
.
environ
[
"WORLD_SIZE"
])
single_node
=
int
(
os
.
environ
[
"OMPI_COMM_WORLD_LOCAL_SIZE"
])
==
int
(
os
.
environ
[
"WORLD_SIZE"
])
if
not
single_node
:
master_node_params
=
os
.
environ
[
"AZ_BATCH_MASTER_NODE"
].
split
(
":"
)
...
...
@@ -745,8 +699,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
os
.
environ
[
"MASTER_PORT"
]
=
DEFAULT_AML_MASTER_PORT
if
verbose
:
utils
.
logger
.
info
(
"NCCL_SOCKET_IFNAME original value = {}"
.
format
(
os
.
environ
[
"NCCL_SOCKET_IFNAME"
]))
utils
.
logger
.
info
(
"NCCL_SOCKET_IFNAME original value = {}"
.
format
(
os
.
environ
[
"NCCL_SOCKET_IFNAME"
]))
os
.
environ
[
"NCCL_SOCKET_IFNAME"
]
=
DEFAULT_AML_NCCL_SOCKET_IFNAME
os
.
environ
[
'LOCAL_RANK'
]
=
os
.
environ
[
"OMPI_COMM_WORLD_LOCAL_RANK"
]
...
...
@@ -754,10 +707,7 @@ def patch_aml_env_for_torch_nccl_backend(master_port=6105, verbose=True):
if
verbose
:
utils
.
logger
.
info
(
"Discovered AzureML settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.
format
(
os
.
environ
[
'RANK'
],
os
.
environ
[
'LOCAL_RANK'
],
os
.
environ
[
'WORLD_SIZE'
],
os
.
environ
[
'MASTER_ADDR'
],
.
format
(
os
.
environ
[
'RANK'
],
os
.
environ
[
'LOCAL_RANK'
],
os
.
environ
[
'WORLD_SIZE'
],
os
.
environ
[
'MASTER_ADDR'
],
os
.
environ
[
'MASTER_PORT'
]))
...
...
@@ -771,8 +721,5 @@ def patch_aws_sm_env_for_torch_nccl_backend(verbose=True):
if
verbose
:
utils
.
logger
.
info
(
"Discovered AWS SageMaker settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}"
.
format
(
os
.
environ
[
'RANK'
],
os
.
environ
[
'LOCAL_RANK'
],
os
.
environ
[
'WORLD_SIZE'
],
os
.
environ
[
'MASTER_ADDR'
],
.
format
(
os
.
environ
[
'RANK'
],
os
.
environ
[
'LOCAL_RANK'
],
os
.
environ
[
'WORLD_SIZE'
],
os
.
environ
[
'MASTER_ADDR'
],
os
.
environ
[
'MASTER_PORT'
]))
deepspeed/comm/config.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
"""
Copyright (c) Microsoft Corporation
Licensed under the MIT license.
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
pydantic
import
BaseModel
from
.constants
import
*
class
CommsConfig
(
BaseModel
):
class
Config
:
validate_all
=
True
validate_assignment
=
True
...
...
@@ -25,6 +25,7 @@ class CommsLoggerConfig(CommsConfig):
class
DeepSpeedCommsConfig
:
def
__init__
(
self
,
ds_config
):
self
.
comms_logger_enabled
=
'comms_logger'
in
ds_config
...
...
deepspeed/comm/constants.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
NCCL_BACKEND
=
'nccl'
MPI_BACKEND
=
'mpi'
...
...
deepspeed/comm/torch.py
View file @
5bcc463d
'''
Copyright 2021 The Microsoft DeepSpeed Team
'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
deepspeed
import
utils
from
.utils
import
*
from
.backend
import
*
from
.comm
import
*
import
torch
def is_torch_two():
    """Return True when the installed torch major version is 2 or newer."""
    major_version = int(torch.__version__.split('.')[0])
    return major_version >= 2
def has_coalescing_manager():
    """Return True when torch.distributed.distributed_c10d provides _coalescing_manager."""
    c10d = getattr(torch.distributed, 'distributed_c10d', None)
    return c10d is not None and hasattr(c10d, '_coalescing_manager')
def has_all_reduce_coalesced():
    """Return True when torch.distributed exposes all_reduce_coalesced."""
    return getattr(torch.distributed, "all_reduce_coalesced", None) is not None
def get_coalescing_manager(group, device, reqs):
    """Return torch's _coalescing_manager context, using the torch-2.x
    keyword signature when available and the legacy positional one otherwise."""
    manager = torch.distributed.distributed_c10d._coalescing_manager
    if is_torch_two():
        return manager(group, device=device, reqs=reqs)
    return manager(group, reqs)
class
TorchBackend
(
Backend
):
"""
...
...
@@ -16,17 +44,13 @@ class TorchBackend(Backend):
so no need to wrap all the functions. We can keep adding wrappers as
needed.
"""
def
__init__
(
self
,
backend
,
timeout
,
init_method
,
rank
=-
1
,
world_size
=-
1
,
name
=
'torch'
):
def
__init__
(
self
,
backend
,
timeout
,
init_method
,
rank
=-
1
,
world_size
=-
1
,
name
=
'torch'
):
super
(
TorchBackend
,
self
).
__init__
()
self
.
torch_version_before_18
=
older_torch
()
self
.
has_allgather_base
=
has_allgather_base
()
self
.
has_reduce_scatter_base
=
has_reduce_scatter_base
()
self
.
has_all_reduce_coalesced
=
has_all_reduce_coalesced
()
self
.
has_coalescing_manager
=
has_coalescing_manager
()
self
.
all_gather_function
=
self
.
get_all_gather_function
()
self
.
reduce_scatter_function
=
self
.
get_reduce_scatter_function
()
self
.
initialized
=
True
self
.
name
=
name
# Future functionality to support ds.initialize() on a single GPU
...
...
@@ -35,6 +59,28 @@ class TorchBackend(Backend):
self
.
single_gpu_mode
=
True
self
.
init_process_group
(
backend
,
timeout
,
init_method
,
rank
,
world_size
)
@
classmethod
def
get_all_gather_function
(
self
):
if
hasattr
(
torch
.
distributed
,
"all_gather_into_tensor"
):
return
torch
.
distributed
.
all_gather_into_tensor
elif
hasattr
(
torch
.
distributed
,
"_all_gather_base"
):
return
torch
.
distributed
.
_all_gather_base
return
None
@
classmethod
def
get_reduce_scatter_function
(
self
):
if
hasattr
(
torch
.
distributed
,
"reduce_scatter_tensor"
):
return
torch
.
distributed
.
reduce_scatter_tensor
elif
hasattr
(
torch
.
distributed
,
"_reduce_scatter_base"
):
return
torch
.
distributed
.
_reduce_scatter_base
return
None
    def has_all_gather_into_tensor(self):
        # True when a flat all-gather primitive was resolved at construction time.
        return self.all_gather_function is not None
    def has_reduce_scatter_tensor(self):
        # True when a flat reduce-scatter primitive was resolved at construction time.
        return self.reduce_scatter_function is not None
def
init_process_group
(
self
,
backend
,
timeout
,
init_method
,
rank
,
world_size
):
if
not
torch
.
distributed
.
is_initialized
():
torch
.
distributed
.
init_process_group
(
backend
,
...
...
@@ -44,30 +90,24 @@ class TorchBackend(Backend):
world_size
=
world_size
)
self
.
using_mpi
=
torch
.
distributed
.
get_backend
()
==
'mpi'
    def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        """All-reduce *tensor* across *group* after mapping *op* to a torch ReduceOp."""
        op = self._reduce_op(op)
        return torch.distributed.all_reduce(tensor=tensor, op=op, group=group, async_op=async_op)
    def all_reduce_coalesced(self, tensors, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
        """ proxy func to torch.distributed.all_reduce_coalesced,
        which is included in PyTorch 1.13 and above
        """
        if not self.has_all_reduce_coalesced:
            raise RuntimeError(f"Current torch version does not have all_reduce_coalesced "
                               f"api (torch.__version__: {torch.__version__})")
        # Map the DeepSpeed ReduceOp wrapper onto the native torch ReduceOp.
        op = self._reduce_op(op)
        return torch.distributed.all_reduce_coalesced(tensors=tensors, op=op, group=group, async_op=async_op)
    def reduce(self, tensor, dst, op=ReduceOp.SUM, group=None, async_op=False):
        """Reduce *tensor* from all ranks in *group* onto rank *dst*."""
        return torch.distributed.reduce(tensor=tensor, dst=dst, op=self._reduce_op(op), group=group, async_op=async_op)
def
reduce_scatter
(
self
,
output
,
input_list
,
op
=
ReduceOp
.
SUM
,
group
=
None
,
async_op
=
False
):
def
reduce_scatter
(
self
,
output
,
input_list
,
op
=
ReduceOp
.
SUM
,
group
=
None
,
async_op
=
False
):
return
torch
.
distributed
.
reduce_scatter
(
output
=
output
,
input_list
=
input_list
,
op
=
self
.
_reduce_op
(
op
),
...
...
@@ -75,48 +115,57 @@ class TorchBackend(Backend):
async_op
=
async_op
)
    def broadcast(self, tensor, src, group=None, async_op=False):
        """Broadcast *tensor* from rank *src* to all ranks in *group*."""
        return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)
    def all_gather(self, tensor_list, tensor, group=None, async_op=False):
        """Gather *tensor* from every rank in *group* into the per-rank *tensor_list*."""
        return torch.distributed.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)
    def all_gather_into_tensor(self, output_tensor, input_tensor, group=None, async_op=False):
        """Flat all-gather into *output_tensor* using the primitive resolved at init.

        NOTE(review): silently returns None when no flat all-gather primitive
        is available — confirm callers check has_all_gather_into_tensor() first.
        """
        if self.has_all_gather_into_tensor():
            return self.all_gather_function(output_tensor=output_tensor,
                                            input_tensor=input_tensor,
                                            group=group,
                                            async_op=async_op)
    def all_gather_base(self, output_tensor, input_tensor, group=None, async_op=False):
        """Legacy flat all-gather via the private torch _all_gather_base API.

        NOTE(review): reads self.has_allgather_base, which this diff appears
        to drop from __init__ — confirm the attribute is still assigned.
        """
        if self.has_allgather_base:
            return torch.distributed.distributed_c10d._all_gather_base(output_tensor=output_tensor,
                                                                       input_tensor=input_tensor,
                                                                       group=group,
                                                                       async_op=async_op)
        else:
            utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
                                 "torch.distributed.all_gather which will result in suboptimal performance. "
                                 "please consider upgrading your pytorch installation.")
            pass
def
reduce_scatter_base
(
self
,
output_tensor
,
input_tensor
,
op
=
ReduceOp
.
SUM
,
group
=
None
,
async_op
=
False
):
if
self
.
has_reduce_scatter_base
:
return
torch
.
distributed
.
_reduce_scatter_base
(
output_tensor
,
input_tensor
,
op
=
self
.
_reduce_op
(
op
),
group
=
group
,
async_op
=
async_op
)
    def all_gather_coalesced(self, output_tensors, input_tensors, group=None, async_op=False):
        """Coalesced all-gather of a list of tensors.

        Uses the custom _all_gather_base_coalesced op when the installed
        PyTorch provides it; otherwise issues per-tensor
        all_gather_into_tensor calls inside torch's _coalescing_manager.
        """
        assert len(output_tensors) == len(input_tensors), ""
        if hasattr(torch.distributed.distributed_c10d, '_all_gather_base_coalesced'):
            # customized PyTorch
            return torch.distributed.distributed_c10d._all_gather_base_coalesced(output_tensors,
                                                                                 input_tensors,
                                                                                 group=group,
                                                                                 async_op=async_op)
        elif has_coalescing_manager():
            reqs = []
            with get_coalescing_manager(group, input_tensors[0].device, reqs):
                for output, input in zip(output_tensors, input_tensors):
                    # Issue each gather asynchronously; the manager coalesces them.
                    handle = torch.distributed.distributed_c10d.all_gather_into_tensor(output,
                                                                                       input,
                                                                                       group=group,
                                                                                       async_op=True)
                    reqs.append(handle)
            if async_op:
                return reqs[-1]
            else:
                reqs[-1].wait()
    def reduce_scatter_tensor(self, output_tensor, input_tensor, op=ReduceOp.SUM, group=None, async_op=False):
        """Flat reduce-scatter via the primitive resolved at init; warns and
        returns None when no such primitive is available."""
        if self.has_reduce_scatter_tensor():
            return self.reduce_scatter_function(output_tensor,
                                                input_tensor,
                                                op=self._reduce_op(op),
                                                group=group,
                                                async_op=async_op)
        else:
            utils.logger.warning("unable to find torch.distributed.reduce_scatter_tensor. will fall back to "
                                 "torch.distributed.reduce_scatter which will result in suboptimal performance. "
                                 "please consider upgrading your pytorch installation.")
            pass
def
all_to_all_single
(
self
,
...
...
@@ -159,25 +208,15 @@ class TorchBackend(Backend):
group
=
group
,
async_op
=
async_op
)
    def barrier(self, group=torch.distributed.GroupMember.WORLD, async_op=False, device_ids=None):
        """Block until all ranks in *group* reach the barrier."""
        # Normalize the torch-style None group to the world group.
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.barrier(group=group, async_op=async_op, device_ids=device_ids)
    def monitored_barrier(self, group=torch.distributed.GroupMember.WORLD, timeout=None, wait_all_ranks=False):
        """Barrier that reports which ranks failed to respond within *timeout*."""
        # Normalize the torch-style None group to the world group.
        if group is None:
            group = torch.distributed.GroupMember.WORLD
        return torch.distributed.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
    def get_rank(self, group=None):
        """Return this process's rank, delegating to torch.distributed.get_rank."""
        return torch.distributed.get_rank(group=group)
...
...
deepspeed/comm/utils.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
os
import
torch
import
inspect
from
deepspeed.utils
import
get_caller_func
def older_torch():
    '''
    Helper to lookup torch version. For versions less than 1.8, torch.dist
    used torch.distributed.group.WORLD as the default group argument instead of None.
    See more details at: https://github.com/pytorch/pytorch/pull/48767
    '''
    version_parts = torch.__version__.split('.')
    major, minor = int(version_parts[0]), int(version_parts[1])
    return major == 1 and minor < 8
def has_allgather_base():
    '''
    Helper to check if torch.distributed has _all_gather_base
    '''
    return getattr(torch.distributed, "_all_gather_base", None) is not None
def has_reduce_scatter_base():
    '''
    Helper to check if torch.distributed has _reduce_scatter_base
    '''
    return getattr(torch.distributed, "_reduce_scatter_base", None) is not None
def
get_local_rank_from_launcher
():
# DeepSpeed launcher will set it so get from there
...
...
@@ -84,11 +58,7 @@ def get_world_size_from_launcher():
def get_default_args(func):
    """Return a mapping of parameter name -> default value for *func*.

    Parameters that have no default are omitted from the result.
    """
    defaults = {}
    for name, param in inspect.signature(func).parameters.items():
        if param.default is not inspect.Parameter.empty:
            defaults[name] = param.default
    return defaults
# We need this hacky function since torch doesn't consistently name or place the input tensor args
...
...
deepspeed/compression/__init__.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from
.compress
import
init_compression
,
redundancy_clean
from
.scheduler
import
compression_scheduler
...
...
deepspeed/compression/basic_layer.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
torch
import
math
...
...
@@ -21,6 +24,7 @@ class QuantAct(nn.Module):
Momentum for updating the activation quantization range.
quant_mode : str, default 'symmetric'
"""
def
__init__
(
self
,
act_range_momentum
=
0.95
,
quant_mode
=
'symmetric'
):
super
(
QuantAct
,
self
).
__init__
()
...
...
@@ -50,10 +54,8 @@ class QuantAct(nn.Module):
self
.
x_min_max
[
1
]
=
x_max
# if do not need momentum, please set self.act_range_momentum = 0
self
.
x_min_max
[
0
]
=
self
.
x_min_max
[
0
]
*
self
.
act_range_momentum
+
x_min
*
(
1
-
self
.
act_range_momentum
)
self
.
x_min_max
[
1
]
=
self
.
x_min_max
[
1
]
*
self
.
act_range_momentum
+
x_max
*
(
1
-
self
.
act_range_momentum
)
self
.
x_min_max
[
0
]
=
self
.
x_min_max
[
0
]
*
self
.
act_range_momentum
+
x_min
*
(
1
-
self
.
act_range_momentum
)
self
.
x_min_max
[
1
]
=
self
.
x_min_max
[
1
]
*
self
.
act_range_momentum
+
x_max
*
(
1
-
self
.
act_range_momentum
)
x_q
=
self
.
act_function
(
x
,
num_bits
,
self
.
x_min_max
[
0
],
self
.
x_min_max
[
1
])
...
...
@@ -61,6 +63,7 @@ class QuantAct(nn.Module):
class
Embedding_Compress
(
nn
.
Embedding
):
def
__init__
(
self
,
*
kargs
):
super
(
Embedding_Compress
,
self
).
__init__
(
*
kargs
)
self
.
weight
.
start_bits
=
None
...
...
@@ -71,17 +74,10 @@ class Embedding_Compress(nn.Embedding):
def
extra_repr
(
self
):
return
'num_embeddings={}, embedding_dim={}, weight_quantization={}'
.
format
(
self
.
num_embeddings
,
self
.
embedding_dim
,
self
.
weight
.
target_bits
)
def
enable_weight_quantization
(
self
,
start_bits
,
target_bits
,
quantization_period
,
weight_quantization_enabled_in_forward
,
quantization_type
,
num_groups
):
self
.
num_embeddings
,
self
.
embedding_dim
,
self
.
weight
.
target_bits
)
def
enable_weight_quantization
(
self
,
start_bits
,
target_bits
,
quantization_period
,
weight_quantization_enabled_in_forward
,
quantization_type
,
num_groups
):
self
.
weight
.
start_bits
=
start_bits
self
.
weight
.
target_bits
=
target_bits
self
.
weight
.
q_period
=
quantization_period
...
...
@@ -105,31 +101,20 @@ class Embedding_Compress(nn.Embedding):
self
.
weight_quantize_num_groups
=
self
.
weight
.
size
(
0
)
def
fix_weight_quantization
(
self
):
self
.
weight
.
data
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight
.
data
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight_quantize_num_groups
).
data
self
.
weight_quantization_enabled_in_forward
=
False
return
None
def
forward
(
self
,
input
):
if
self
.
weight_quantization_enabled_in_forward
and
self
.
weight_quantization_enabled
:
weight
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
weight
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight_quantize_num_groups
)
else
:
weight
=
self
.
weight
out
=
nn
.
functional
.
embedding
(
input
,
weight
,
self
.
padding_idx
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
)
out
=
nn
.
functional
.
embedding
(
input
,
weight
,
self
.
padding_idx
,
self
.
max_norm
,
self
.
norm_type
,
self
.
scale_grad_by_freq
,
self
.
sparse
)
return
out
...
...
@@ -137,6 +122,7 @@ class LinearLayer_Compress(nn.Linear):
"""
Linear layer with compression.
"""
def
__init__
(
self
,
*
kargs
,
bias
=
True
):
super
(
LinearLayer_Compress
,
self
).
__init__
(
*
kargs
,
bias
=
bias
)
self
.
sparse_pruning_method
=
None
...
...
@@ -169,8 +155,7 @@ class LinearLayer_Compress(nn.Linear):
mask
=
mask
.
to
(
self
.
weight
.
device
)
elif
method
==
'topk'
:
self
.
sparse_mask_scores
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
weight
.
size
()))
self
.
sparse_mask_scores
.
data
=
self
.
sparse_mask_scores
.
data
.
to
(
self
.
weight
.
device
)
self
.
sparse_mask_scores
.
data
=
self
.
sparse_mask_scores
.
data
.
to
(
self
.
weight
.
device
)
init
.
kaiming_uniform_
(
self
.
sparse_mask_scores
,
a
=
math
.
sqrt
(
5
))
mask
=
None
else
:
...
...
@@ -209,11 +194,9 @@ class LinearLayer_Compress(nn.Linear):
raise
NotImplementedError
else
:
self
.
head_pruning_ratio
=
ratio
self
.
head_pruning_scores
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
self
.
num_heads
))
# we apply the pruning to O matrix
self
.
head_pruning_scores
.
data
=
self
.
head_pruning_scores
.
data
.
to
(
self
.
weight
.
device
)
self
.
head_pruning_scores
=
nn
.
Parameter
(
torch
.
Tensor
(
1
,
self
.
num_heads
))
# we apply the pruning to O matrix
self
.
head_pruning_scores
.
data
=
self
.
head_pruning_scores
.
data
.
to
(
self
.
weight
.
device
)
init
.
kaiming_uniform_
(
self
.
head_pruning_scores
,
a
=
math
.
sqrt
(
5
))
def
fix_sparse_pruning_helper
(
self
):
...
...
@@ -279,18 +262,17 @@ class LinearLayer_Compress(nn.Linear):
start_bits
=
self
.
weight
.
start_bits
target_bits
=
self
.
weight
.
target_bits
q_period
=
self
.
weight
.
q_period
self
.
weight
=
nn
.
Parameter
(
self
.
weight
.
data
.
t
().
reshape
(
num_heads
,
-
1
)[
mask
.
view
(
-
1
),
:].
reshape
(
-
1
,
shape
).
t
())
self
.
weight
=
nn
.
Parameter
(
self
.
weight
.
data
.
t
().
reshape
(
num_heads
,
-
1
)[
mask
.
view
(
-
1
),
:].
reshape
(
-
1
,
shape
).
t
())
self
.
weight
.
start_bits
=
start_bits
self
.
weight
.
target_bits
=
target_bits
self
.
weight
.
q_period
=
q_period
else
:
shape
=
self
.
weight
.
size
()
self
.
weight
.
data
=
(
self
.
weight
.
data
.
t
().
reshape
(
self
.
num_heads
,
-
1
)
*
mask
.
view
(
-
1
,
1
)).
reshape
(
shape
[
1
],
shape
[
0
]).
t
()
self
.
weight
.
data
=
(
self
.
weight
.
data
.
t
().
reshape
(
self
.
num_heads
,
-
1
)
*
mask
.
view
(
-
1
,
1
)).
reshape
(
shape
[
1
],
shape
[
0
]).
t
()
if
self
.
head_pruning_method
==
'topk'
:
del
self
.
head_pruning_scores
...
...
@@ -316,37 +298,26 @@ class LinearLayer_Compress(nn.Linear):
if
self
.
sparse_pruning_method
==
'l1'
:
return
self
.
sparse_pruning_mask
.
to
(
self
.
weight
.
device
)
elif
self
.
sparse_pruning_method
==
'topk'
:
return
TopKBinarizer
.
apply
(
self
.
sparse_mask_scores
,
self
.
sparse_pruning_ratio
,
False
)
return
TopKBinarizer
.
apply
(
self
.
sparse_mask_scores
,
self
.
sparse_pruning_ratio
,
False
)
else
:
raise
NotImplementedError
if
pruning_type
==
'row'
:
if
self
.
row_pruning_method
==
'l1'
:
return
self
.
row_pruning_mask
.
to
(
self
.
weight
.
device
)
elif
self
.
row_pruning_method
==
'topk'
:
return
TopKBinarizer
.
apply
(
self
.
row_mask_scores
,
self
.
row_pruning_ratio
,
False
)
return
TopKBinarizer
.
apply
(
self
.
row_mask_scores
,
self
.
row_pruning_ratio
,
False
)
else
:
raise
NotImplementedError
elif
pruning_type
==
'head'
:
if
self
.
head_pruning_method
==
'topk'
:
return
TopKBinarizer
.
apply
(
self
.
head_pruning_scores
,
self
.
head_pruning_ratio
,
False
)
return
TopKBinarizer
.
apply
(
self
.
head_pruning_scores
,
self
.
head_pruning_ratio
,
False
)
else
:
raise
NotImplementedError
else
:
raise
NotImplementedError
def
enable_weight_quantization
(
self
,
start_bits
,
target_bits
,
quantization_period
,
weight_quantization_enabled_in_forward
,
quantization_type
,
num_groups
):
def
enable_weight_quantization
(
self
,
start_bits
,
target_bits
,
quantization_period
,
weight_quantization_enabled_in_forward
,
quantization_type
,
num_groups
):
self
.
weight
.
start_bits
=
start_bits
self
.
weight
.
target_bits
=
target_bits
self
.
weight
.
q_period
=
quantization_period
...
...
@@ -369,10 +340,7 @@ class LinearLayer_Compress(nn.Linear):
self
.
weight_quantize_num_groups
=
num_groups
def
fix_weight_quantization
(
self
):
self
.
weight
.
data
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight
.
data
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight_quantize_num_groups
).
data
self
.
weight_quantization_enabled_in_forward
=
False
return
None
...
...
@@ -391,18 +359,12 @@ class LinearLayer_Compress(nn.Linear):
def
head_pruning_reshape
(
self
,
w
,
mask
):
shape
=
w
.
shape
return
(
w
.
t
().
reshape
(
self
.
num_heads
,
-
1
)
*
mask
.
view
(
-
1
,
1
)).
reshape
(
shape
[
1
],
shape
[
0
]).
t
()
return
(
w
.
t
().
reshape
(
self
.
num_heads
,
-
1
)
*
mask
.
view
(
-
1
,
1
)).
reshape
(
shape
[
1
],
shape
[
0
]).
t
()
def
forward
(
self
,
input
,
skip_bias_add
=
False
):
if
self
.
weight_quantization_enabled_in_forward
and
self
.
weight_quantization_enabled
:
weight
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
weight
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight_quantize_num_groups
)
bias
=
self
.
bias
else
:
...
...
@@ -428,11 +390,7 @@ class LinearLayer_Compress(nn.Linear):
num_groups
=
input
.
numel
()
//
input
.
size
(
-
1
)
else
:
num_groups
=
1
input
=
self
.
activation_quantizer
(
input
,
self
.
activation_quantization_bits
,
None
,
None
,
num_groups
)
input
=
self
.
activation_quantizer
(
input
,
self
.
activation_quantization_bits
,
None
,
None
,
num_groups
)
if
skip_bias_add
:
# used for mpu linear layers
...
...
@@ -447,6 +405,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
"""
Conv2D layer with compression.
"""
def
__init__
(
self
,
*
kargs
):
super
(
Conv2dLayer_Compress
,
self
).
__init__
(
*
kargs
)
self
.
sparse_pruning_method
=
None
...
...
@@ -478,10 +437,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
output
=
s
.
format
(
**
self
.
__dict__
)
return
output
+
' sparse pruning={}, channel pruning={}, activation quantization={}, weight_quantization={}'
.
format
(
self
.
sparse_pruning_method
is
not
None
,
self
.
channel_pruning_method
is
not
None
,
self
.
activation_quantization_method
is
not
None
,
self
.
weight
.
target_bits
)
self
.
sparse_pruning_method
is
not
None
,
self
.
channel_pruning_method
is
not
None
,
self
.
activation_quantization_method
is
not
None
,
self
.
weight
.
target_bits
)
def
enable_sparse_pruning
(
self
,
ratio
,
method
):
self
.
sparse_pruning_ratio
=
ratio
...
...
@@ -493,8 +450,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask
=
mask
.
to
(
self
.
weight
.
device
)
elif
method
==
'topk'
:
self
.
sparse_mask_scores
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
weight
.
size
()))
self
.
sparse_mask_scores
.
data
=
self
.
sparse_mask_scores
.
data
.
to
(
self
.
weight
.
device
)
self
.
sparse_mask_scores
.
data
=
self
.
sparse_mask_scores
.
data
.
to
(
self
.
weight
.
device
)
init
.
kaiming_uniform_
(
self
.
sparse_mask_scores
,
a
=
math
.
sqrt
(
5
))
mask
=
None
else
:
...
...
@@ -514,13 +470,8 @@ class Conv2dLayer_Compress(nn.Conv2d):
mask
=
mask
.
view
(
-
1
,
1
,
1
,
1
)
mask
=
mask
.
to
(
self
.
weight
.
device
)
elif
method
==
'topk'
:
self
.
channel_mask_scores
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
weight
.
size
(
0
),
1
,
1
,
1
))
self
.
channel_mask_scores
.
data
=
self
.
channel_mask_scores
.
data
.
to
(
self
.
weight
.
device
)
self
.
channel_mask_scores
=
nn
.
Parameter
(
torch
.
Tensor
(
self
.
weight
.
size
(
0
),
1
,
1
,
1
))
self
.
channel_mask_scores
.
data
=
self
.
channel_mask_scores
.
data
.
to
(
self
.
weight
.
device
)
init
.
kaiming_uniform_
(
self
.
channel_mask_scores
,
a
=
math
.
sqrt
(
5
))
mask
=
None
else
:
...
...
@@ -579,39 +530,27 @@ class Conv2dLayer_Compress(nn.Conv2d):
if
self
.
sparse_pruning_method
==
'l1'
:
return
self
.
sparse_pruning_mask
.
to
(
self
.
weight
.
device
)
elif
self
.
sparse_pruning_method
==
'topk'
:
return
TopKBinarizer
.
apply
(
self
.
sparse_mask_scores
,
self
.
sparse_pruning_ratio
,
False
)
return
TopKBinarizer
.
apply
(
self
.
sparse_mask_scores
,
self
.
sparse_pruning_ratio
,
False
)
else
:
raise
NotImplementedError
elif
pruning_type
==
'channel'
:
if
self
.
channel_pruning_method
==
'l1'
:
return
self
.
channel_pruning_mask
.
to
(
self
.
weight
.
device
)
elif
self
.
channel_pruning_method
==
'topk'
:
return
TopKBinarizer
.
apply
(
self
.
channel_mask_scores
,
self
.
channel_pruning_ratio
,
False
)
return
TopKBinarizer
.
apply
(
self
.
channel_mask_scores
,
self
.
channel_pruning_ratio
,
False
)
else
:
raise
NotImplementedError
else
:
raise
NotImplementedError
def
fix_weight_quantization
(
self
):
self
.
weight
.
data
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight
.
data
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight_quantize_num_groups
).
data
self
.
weight_quantization_enabled_in_forward
=
False
return
None
def
enable_weight_quantization
(
self
,
start_bits
,
target_bits
,
quantization_period
,
weight_quantization_enabled_in_forward
,
quantization_type
,
num_groups
):
def
enable_weight_quantization
(
self
,
start_bits
,
target_bits
,
quantization_period
,
weight_quantization_enabled_in_forward
,
quantization_type
,
num_groups
):
self
.
weight
.
start_bits
=
start_bits
self
.
weight
.
target_bits
=
target_bits
self
.
weight
.
q_period
=
quantization_period
...
...
@@ -642,10 +581,7 @@ class Conv2dLayer_Compress(nn.Conv2d):
def
forward
(
self
,
input
):
if
self
.
weight_quantization_enabled_in_forward
and
self
.
weight_quantization_enabled
:
weight
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
weight
=
self
.
weight_quantizer
(
self
.
weight
,
self
.
weight
.
target_bits
,
None
,
None
,
self
.
weight_quantize_num_groups
)
bias
=
self
.
bias
else
:
...
...
@@ -667,22 +603,13 @@ class Conv2dLayer_Compress(nn.Conv2d):
num_groups
=
input
.
numel
()
//
input
[
0
].
numel
()
else
:
num_groups
=
1
input
=
self
.
activation_quantizer
(
input
,
self
.
activation_quantization_bits
,
None
,
None
,
num_groups
)
input
=
self
.
activation_quantizer
(
input
,
self
.
activation_quantization_bits
,
None
,
None
,
num_groups
)
return
nn
.
functional
.
conv2d
(
input
,
weight
,
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
)
return
nn
.
functional
.
conv2d
(
input
,
weight
,
bias
,
self
.
stride
,
self
.
padding
,
self
.
dilation
,
self
.
groups
)
class
BNLayer_Compress
(
nn
.
BatchNorm2d
):
def
fix_channel_pruning_helper
(
self
,
mask
,
dim_reduction
=
True
):
self
.
weight
=
nn
.
Parameter
(
self
.
weight
.
data
[
mask
.
view
(
-
1
)])
self
.
bias
=
nn
.
Parameter
(
self
.
bias
.
data
[
mask
.
view
(
-
1
)])
...
...
@@ -770,6 +697,7 @@ def _gather(input_):
class
_CopyToModelParallelRegion
(
torch
.
autograd
.
Function
):
"""Pass the input to the model parallel region."""
@
staticmethod
def
forward
(
ctx
,
input_
):
return
input_
...
...
@@ -781,6 +709,7 @@ class _CopyToModelParallelRegion(torch.autograd.Function):
class
_ReduceFromModelParallelRegion
(
torch
.
autograd
.
Function
):
"""All-redcue the input from the model parallel region."""
@
staticmethod
def
forward
(
ctx
,
input_
):
return
_reduce
(
input_
)
...
...
@@ -792,6 +721,7 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function):
class
_ScatterToModelParallelRegion
(
torch
.
autograd
.
Function
):
"""Split the input and keep only the corresponding chuck to the rank."""
@
staticmethod
def
forward
(
ctx
,
input_
):
return
_split
(
input_
)
...
...
@@ -803,6 +733,7 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):
class
_GatherFromModelParallelRegion
(
torch
.
autograd
.
Function
):
"""Gather the input from model parallel region and concatinate."""
@
staticmethod
def
forward
(
ctx
,
input_
):
return
_gather
(
input_
)
...
...
@@ -834,13 +765,8 @@ def gather_from_model_parallel_region(input_):
class
ColumnParallelLinear_Compress
(
LinearLayer_Compress
):
def
__init__
(
self
,
mpu
,
input_size
,
output_size
,
bias
=
True
,
gather_output
=
True
,
skip_bias_add
=
False
):
def
__init__
(
self
,
mpu
,
input_size
,
output_size
,
bias
=
True
,
gather_output
=
True
,
skip_bias_add
=
False
):
# Keep input parameters
global
g_mpu
g_mpu
=
mpu
...
...
@@ -854,10 +780,7 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
assert
output_size
%
world_size
==
0
self
.
output_size_per_partition
=
output_size
//
world_size
super
(
ColumnParallelLinear_Compress
,
self
).
__init__
(
self
.
input_size
,
self
.
output_size_per_partition
,
bias
=
bias
)
super
(
ColumnParallelLinear_Compress
,
self
).
__init__
(
self
.
input_size
,
self
.
output_size_per_partition
,
bias
=
bias
)
def
forward
(
self
,
input_
):
# Set up backprop all-reduce.
...
...
@@ -877,13 +800,8 @@ class ColumnParallelLinear_Compress(LinearLayer_Compress):
class
RowParallelLinear_Compress
(
LinearLayer_Compress
):
def
__init__
(
self
,
mpu
,
input_size
,
output_size
,
bias
=
True
,
input_is_parallel
=
False
,
skip_bias_add
=
False
):
def
__init__
(
self
,
mpu
,
input_size
,
output_size
,
bias
=
True
,
input_is_parallel
=
False
,
skip_bias_add
=
False
):
# Keep input parameters
global
g_mpu
g_mpu
=
mpu
...
...
@@ -897,10 +815,7 @@ class RowParallelLinear_Compress(LinearLayer_Compress):
assert
input_size
%
world_size
==
0
self
.
input_size_per_partition
=
input_size
//
world_size
super
(
RowParallelLinear_Compress
,
self
).
__init__
(
self
.
input_size_per_partition
,
self
.
output_size
,
bias
=
bias
)
super
(
RowParallelLinear_Compress
,
self
).
__init__
(
self
.
input_size_per_partition
,
self
.
output_size
,
bias
=
bias
)
def
forward
(
self
,
input_
):
# Set up backprop all-reduce.
...
...
deepspeed/compression/compress.py
View file @
5bcc463d
'''Copyright The Microsoft DeepSpeed Team'''
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import
re
from
.helper
import
compression_preparation
,
fix_compression
,
recursive_getattr
,
is_module_compressible
...
...
@@ -13,21 +16,13 @@ def check_deepspeed_config(config):
if
isinstance
(
config
,
dict
):
return
config
elif
os
.
path
.
exists
(
config
):
return
json
.
load
(
open
(
config
,
"r"
),
object_pairs_hook
=
dict_raise_error_on_duplicate_keys
)
return
json
.
load
(
open
(
config
,
"r"
),
object_pairs_hook
=
dict_raise_error_on_duplicate_keys
)
else
:
raise
ValueError
(
f
"Expected a string path to an existing deepspeed config, or a dictionary. Received:
{
config
}
"
)
f
"Expected a string path to an existing deepspeed config, or a dictionary. Received:
{
config
}
"
)
def
get_module_name
(
group_name
,
model
,
key_word
,
exist_module_name
,
mpu
=
None
,
verbose
=
True
):
def
get_module_name
(
group_name
,
model
,
key_word
,
exist_module_name
,
mpu
=
None
,
verbose
=
True
):
'''
get the associated module name from the model based on the key_word provided by users
'''
...
...
@@ -40,8 +35,7 @@ def get_module_name(group_name,
if
name
in
exist_module_name
and
verbose
:
# logger.warning
raise
ValueError
(
f
"
{
name
}
is already added to compression, please check your config file for
{
group_name
}
."
)
f
"
{
name
}
is already added to compression, please check your config file for
{
group_name
}
."
)
if
name
not
in
exist_module_name
:
exist_module_name
.
add
(
name
)
return_module_name
.
append
(
name
)
...
...
@@ -56,8 +50,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
continue
# for loop different methods, i.e., weight quantization, activation quantization etc
exist_module_name
=
set
()
shared_parameters
=
method_content
[
SHARED_PARAMETERS
]
# get all the shared parameters
shared_parameters
=
method_content
[
SHARED_PARAMETERS
]
# get all the shared parameters
for
group_name
,
method_parameters
in
method_content
[
DIFFERENT_GROUPS
].
items
():
# for loop different groups, i.e., weight quantization group 1, weight quantization group 2 etc
module_name_list
=
[]
...
...
@@ -65,8 +58,13 @@ def get_compress_methods(model, compress_methods, mpu=None):
if
method_parameters
[
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE
]:
# this is used for head/row/channel pruning, if users provide the related module scope, we can shrink the layer dim for them
# otherwise we just mask those as zeros
for
key_word
,
related_key_words
in
zip
(
method_parameters
[
DIFFERENT_GROUPS_MODULE_SCOPE
],
method_parameters
[
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE
]):
module_name
,
exist_module_name
=
get_module_name
(
group_name
,
model
,
key_word
,
exist_module_name
,
mpu
=
mpu
)
for
key_word
,
related_key_words
in
zip
(
method_parameters
[
DIFFERENT_GROUPS_MODULE_SCOPE
],
method_parameters
[
DIFFERENT_GROUPS_RELATED_MODULE_SCOPE
]):
module_name
,
exist_module_name
=
get_module_name
(
group_name
,
model
,
key_word
,
exist_module_name
,
mpu
=
mpu
)
module_name_list
.
append
(
module_name
)
tmp_related_module_name_list
=
[]
for
rkw
in
related_key_words
:
...
...
@@ -76,7 +74,11 @@ def get_compress_methods(model, compress_methods, mpu=None):
related_module_name_list
.
append
(
tmp_related_module_name_list
)
else
:
for
key_word
in
method_parameters
[
DIFFERENT_GROUPS_MODULE_SCOPE
]:
module_name
,
exist_module_name
=
get_module_name
(
group_name
,
model
,
key_word
,
exist_module_name
,
mpu
=
mpu
)
module_name
,
exist_module_name
=
get_module_name
(
group_name
,
model
,
key_word
,
exist_module_name
,
mpu
=
mpu
)
module_name_list
.
append
(
module_name
)
if
module_name_list
:
...
...
@@ -85,13 +87,7 @@ def get_compress_methods(model, compress_methods, mpu=None):
**
(
method_parameters
.
copy
().
pop
(
DIFFERENT_GROUPS_PARAMETERS
)),
**
shared_parameters
}
compression_item
=
[
module_name_list
,
related_module_name_list
,
{
method
:
combined_method_parameters
}
]
compression_item
=
[
module_name_list
,
related_module_name_list
,
{
method
:
combined_method_parameters
}]
layer_added_compress_methods
.
append
(
compression_item
)
return
layer_added_compress_methods
...
...
@@ -118,9 +114,7 @@ def init_compression(model, deepspeed_config, teacher_model=None, mpu=None):
assert
teacher_model
is
not
None
,
"Teacher model is required for layer reduction"
student_initialization
(
c_model
,
teacher_model
,
deepspeed_config
)
layer_added_compress_methods
=
get_compress_methods
(
c_model
,
compress_methods
,
mpu
=
mpu
)
layer_added_compress_methods
=
get_compress_methods
(
c_model
,
compress_methods
,
mpu
=
mpu
)
compression_preparation
(
c_model
,
layer_added_compress_methods
,
mpu
)
return
model
...
...
@@ -143,31 +137,20 @@ def redundancy_clean(model, deepspeed_config, mpu=None):
else
:
c_model
=
model
layer_added_compress_methods_tmp
=
get_compress_methods
(
c_model
,
compress_methods
,
mpu
=
mpu
)
layer_added_compress_methods_tmp
=
get_compress_methods
(
c_model
,
compress_methods
,
mpu
=
mpu
)
# sort methods
order_list
=
[
WEIGHT_QUANTIZATION
,
SPARSE_PRUNING
,
ROW_PRUNING
,
HEAD_PRUNING
,
CHANNEL_PRUNING
,
ACTIVATION_QUANTIZATION
WEIGHT_QUANTIZATION
,
SPARSE_PRUNING
,
ROW_PRUNING
,
HEAD_PRUNING
,
CHANNEL_PRUNING
,
ACTIVATION_QUANTIZATION
]
layer_added_compress_methods
=
sorted
(
layer_added_compress_methods_tmp
,
key
=
lambda
x
:
order_list
.
index
(
list
(
x
[
2
].
keys
())[
0
]))
layer_added_compress_methods
=
sorted
(
layer_added_compress_methods_tmp
,
key
=
lambda
x
:
order_list
.
index
(
list
(
x
[
2
].
keys
())[
0
]))
for
module_name_lists
,
related_module_name_lists
,
compression_technique
in
layer_added_compress_methods
:
stored_mask
=
[]
need_mask
=
True
if
related_module_name_lists
else
False
for
i
,
mnl
in
enumerate
(
module_name_lists
):
for
module_name
in
mnl
:
mask
=
fix_compression
(
c_model
,
module_name
,
compression_technique
,
dim_reduction
=
need_mask
)
mask
=
fix_compression
(
c_model
,
module_name
,
compression_technique
,
dim_reduction
=
need_mask
)
if
need_mask
:
stored_mask
.
append
(
mask
)
if
need_mask
:
...
...
@@ -219,10 +202,8 @@ def student_initialization(student_model, teacher_model, deepspeed_config):
'''
assert
len
(
student_layer
)
==
len
(
teacher_layer
)
for
s_name
,
t_name
in
zip
(
student_layer
,
teacher_layer
):
s_module
=
recursive_getattr
(
student_model
,
module_name_prefix
+
'.'
+
str
(
s_name
))
t_module
=
recursive_getattr
(
teacher_model
,
module_name_prefix
+
'.'
+
str
(
t_name
))
s_module
=
recursive_getattr
(
student_model
,
module_name_prefix
+
'.'
+
str
(
s_name
))
t_module
=
recursive_getattr
(
teacher_model
,
module_name_prefix
+
'.'
+
str
(
t_name
))
for
s_param
,
t_param
in
zip
(
s_module
.
parameters
(),
t_module
.
parameters
()):
s_param
.
data
.
copy_
(
t_param
.
data
)
for
name
in
other_module_name
:
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment