Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
32f81f14
Unverified
Commit
32f81f14
authored
May 19, 2023
by
digger yu
Committed by
GitHub
May 19, 2023
Browse files
[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)
parent
21e29e22
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
12 additions
and
12 deletions
+12
-12
colossalai/amp/torch_amp/_grad_scaler.py
colossalai/amp/torch_amp/_grad_scaler.py
+1
-1
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
...salai/auto_parallel/meta_profiler/meta_registry/linear.py
+1
-1
colossalai/auto_parallel/passes/runtime_apply_pass.py
colossalai/auto_parallel/passes/runtime_apply_pass.py
+1
-1
colossalai/auto_parallel/passes/runtime_preparation_pass.py
colossalai/auto_parallel/passes/runtime_preparation_pass.py
+2
-2
colossalai/autochunk/trace_flow.py
colossalai/autochunk/trace_flow.py
+3
-3
colossalai/autochunk/trace_indice.py
colossalai/autochunk/trace_indice.py
+4
-4
No files found.
colossalai/amp/torch_amp/_grad_scaler.py
View file @
32f81f14
...
@@ -240,7 +240,7 @@ class GradScaler(object):
...
@@ -240,7 +240,7 @@ class GradScaler(object):
for
grads
in
per_dtype_grads
.
values
():
for
grads
in
per_dtype_grads
.
values
():
torch
.
_amp_foreach_non_finite_check_and_unscale_
(
grads
,
per_device_found_inf
.
get
(
device
),
torch
.
_amp_foreach_non_finite_check_and_unscale_
(
grads
,
per_device_found_inf
.
get
(
device
),
per_device_inv_scale
.
get
(
device
))
per_device_inv_scale
.
get
(
device
))
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
# For tensor parallel param
e
ters it should be all-reduced over tensor parallel process group
if
gpc
.
is_initialized
(
ParallelMode
.
MODEL
)
and
gpc
.
get_world_size
(
ParallelMode
.
MODEL
)
>
1
:
if
gpc
.
is_initialized
(
ParallelMode
.
MODEL
)
and
gpc
.
get_world_size
(
ParallelMode
.
MODEL
)
>
1
:
vals
=
[
val
for
val
in
per_device_found_inf
.
_per_device_tensors
.
values
()]
vals
=
[
val
for
val
in
per_device_found_inf
.
_per_device_tensors
.
values
()]
coalesced
=
_flatten_dense_tensors
(
vals
)
coalesced
=
_flatten_dense_tensors
(
vals
)
...
...
colossalai/auto_parallel/meta_profiler/meta_registry/linear.py
View file @
32f81f14
...
@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
...
@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else
:
else
:
_is_batch_dims_same
=
False
_is_batch_dims_same
=
False
# ret
i
reve dimensions
# retr
i
eve dimensions
input_dim_00
=
input_tensors
[
0
].
shape
[
-
2
]
input_dim_00
=
input_tensors
[
0
].
shape
[
-
2
]
input_dim_01
=
input_tensors
[
0
].
shape
[
-
1
]
input_dim_01
=
input_tensors
[
0
].
shape
[
-
1
]
input_dim_10
=
input_tensors
[
1
].
shape
[
-
2
]
input_dim_10
=
input_tensors
[
1
].
shape
[
-
2
]
...
...
colossalai/auto_parallel/passes/runtime_apply_pass.py
View file @
32f81f14
...
@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
...
@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
return
gm
return
gm
def
_act_annotat
a
ion_pass
(
gm
:
torch
.
fx
.
GraphModule
):
def
_act_annotation_pass
(
gm
:
torch
.
fx
.
GraphModule
):
"""
"""
This pass is used to add the act annotation to the new inserted nodes.
This pass is used to add the act annotation to the new inserted nodes.
"""
"""
...
...
colossalai/auto_parallel/passes/runtime_preparation_pass.py
View file @
32f81f14
...
@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
...
@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
return
size
return
size
def
solution_annotat
at
ion_pass
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
def
solution_annotation_pass
(
gm
:
torch
.
fx
.
GraphModule
,
solution
:
List
[
int
],
strategies_constructor
:
StrategiesConstructor
):
strategies_constructor
:
StrategiesConstructor
):
"""
"""
This method is used to stick the solution strategy to the nodes and add the information
This method is used to stick the solution strategy to the nodes and add the information
...
@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
...
@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
device_mesh
:
DeviceMesh
,
device_mesh
:
DeviceMesh
,
strategies_constructor
:
StrategiesConstructor
,
strategies_constructor
:
StrategiesConstructor
,
overlap
=
False
):
overlap
=
False
):
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
=
solution_annotat
at
ion_pass
(
gm
,
sharding_spec_convert_dict
,
origin_node_sharding_spec_dict
,
comm_actions_dict
=
solution_annotation_pass
(
gm
,
solution
,
strategies_constructor
)
gm
,
solution
,
strategies_constructor
)
gm
=
size_value_converting_pass
(
gm
,
device_mesh
)
gm
=
size_value_converting_pass
(
gm
,
device_mesh
)
gm
=
node_args_converting_pass
(
gm
,
device_mesh
)
gm
=
node_args_converting_pass
(
gm
,
device_mesh
)
...
...
colossalai/autochunk/trace_flow.py
View file @
32f81f14
...
@@ -64,7 +64,7 @@ class TraceFlow(object):
...
@@ -64,7 +64,7 @@ class TraceFlow(object):
return
False
return
False
return
True
return
True
def
_ass
g
in_single_node_flow
(
def
_assi
g
n_single_node_flow
(
self
,
self
,
arg_node
:
Node
,
arg_node
:
Node
,
start_idx
:
int
,
start_idx
:
int
,
...
@@ -177,7 +177,7 @@ class TraceFlow(object):
...
@@ -177,7 +177,7 @@ class TraceFlow(object):
if
get_node_shape
(
arg
)
is
None
:
if
get_node_shape
(
arg
)
is
None
:
continue
continue
arg_list
.
append
(
arg
)
arg_list
.
append
(
arg
)
flow_flag
=
self
.
_ass
g
in_single_node_flow
(
flow_flag
=
self
.
_assi
g
n_single_node_flow
(
arg
,
arg
,
start_idx
,
start_idx
,
end_idx
,
end_idx
,
...
@@ -315,7 +315,7 @@ class TraceFlow(object):
...
@@ -315,7 +315,7 @@ class TraceFlow(object):
chunk_info
[
"args"
][
"prepose_nodes"
]
=
prepose_nodes
chunk_info
[
"args"
][
"prepose_nodes"
]
=
prepose_nodes
def
_get_non_chunk_inputs
(
self
,
chunk_info
,
start_idx
,
end_idx
):
def
_get_non_chunk_inputs
(
self
,
chunk_info
,
start_idx
,
end_idx
):
# we need to log input nodes to avoid delet
e
ing them in the loop
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list
=
self
.
node_mgr
.
get_node_slice_by_idx
(
start_idx
,
end_idx
+
1
)
chunk_node_list
=
self
.
node_mgr
.
get_node_slice_by_idx
(
start_idx
,
end_idx
+
1
)
# also need to get some prepose node's arg out of non_chunk_inputs
# also need to get some prepose node's arg out of non_chunk_inputs
for
n
in
chunk_info
[
"args"
][
"prepose_nodes"
]:
for
n
in
chunk_info
[
"args"
][
"prepose_nodes"
]:
...
...
colossalai/autochunk/trace_indice.py
View file @
32f81f14
...
@@ -461,7 +461,7 @@ class TraceIndice(object):
...
@@ -461,7 +461,7 @@ class TraceIndice(object):
nodes_in
.
append
(
node_in
)
nodes_in
.
append
(
node_in
)
self
.
_inherit_more_indice_from_node_with_exclude
(
node_in
,
node
)
self
.
_inherit_more_indice_from_node_with_exclude
(
node_in
,
node
)
def
_ass
g
in_no_change_indice
(
self
,
node
,
idx
):
def
_assi
g
n_no_change_indice
(
self
,
node
,
idx
):
self
.
_assign_indice_as_input
(
node
,
idx
)
self
.
_assign_indice_as_input
(
node
,
idx
)
for
node_in
in
node
.
args
:
for
node_in
in
node
.
args
:
if
type
(
node_in
)
==
type
(
node
):
if
type
(
node_in
)
==
type
(
node
):
...
@@ -792,7 +792,7 @@ class TraceIndice(object):
...
@@ -792,7 +792,7 @@ class TraceIndice(object):
self
.
_add_dim
(
node_idx
,
i
)
self
.
_add_dim
(
node_idx
,
i
)
dim_from
.
reverse
()
dim_from
.
reverse
()
# inhe
i
rt indice from current node
# inher
i
t indice from current node
if
len
(
dim_from
)
!=
0
and
len
(
dim_to
)
!=
0
:
if
len
(
dim_from
)
!=
0
and
len
(
dim_to
)
!=
0
:
if
dim_diff
==
1
:
if
dim_diff
==
1
:
if
origin_shape
[
dim_from
[
0
]]
==
1
:
if
origin_shape
[
dim_from
[
0
]]
==
1
:
...
@@ -852,7 +852,7 @@ class TraceIndice(object):
...
@@ -852,7 +852,7 @@ class TraceIndice(object):
elif
"split"
==
node_name
:
elif
"split"
==
node_name
:
self
.
_assign_split_indice
(
node
,
idx
)
self
.
_assign_split_indice
(
node
,
idx
)
elif
any
(
i
==
node_name
for
i
in
[
"to"
,
"contiguous"
,
"clone"
,
"type"
,
"float"
]):
elif
any
(
i
==
node_name
for
i
in
[
"to"
,
"contiguous"
,
"clone"
,
"type"
,
"float"
]):
self
.
_ass
g
in_no_change_indice
(
node
,
idx
)
self
.
_assi
g
n_no_change_indice
(
node
,
idx
)
elif
"new_ones"
==
node_name
:
elif
"new_ones"
==
node_name
:
self
.
_assign_all_indice
(
node
,
idx
)
self
.
_assign_all_indice
(
node
,
idx
)
elif
"flatten"
==
node_name
:
elif
"flatten"
==
node_name
:
...
@@ -914,7 +914,7 @@ class TraceIndice(object):
...
@@ -914,7 +914,7 @@ class TraceIndice(object):
elif
"conv2d"
==
node_name
:
elif
"conv2d"
==
node_name
:
self
.
_assign_conv2d_indice
(
node
,
idx
)
self
.
_assign_conv2d_indice
(
node
,
idx
)
elif
"identity"
==
node_name
:
elif
"identity"
==
node_name
:
self
.
_ass
g
in_no_change_indice
(
node
,
idx
)
self
.
_assi
g
n_no_change_indice
(
node
,
idx
)
elif
any
(
n
==
node_name
for
n
in
[
"sigmoid"
,
"dropout"
,
"relu"
,
"silu"
,
"gelu"
]):
elif
any
(
n
==
node_name
for
n
in
[
"sigmoid"
,
"dropout"
,
"relu"
,
"silu"
,
"gelu"
]):
self
.
_assign_elementwise_indice
(
node
,
idx
)
self
.
_assign_elementwise_indice
(
node
,
idx
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment