Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9a3321e9
Unverified
Commit
9a3321e9
authored
Mar 26, 2024
by
Edenzzzz
Committed by
GitHub
Mar 26, 2024
Browse files
Merge pull request #5515 from Edenzzzz/fix_layout_convert
Fix layout convertor caching
parents
cbe34c55
18edcd53
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
3 deletions
+13
-3
colossalai/tensor/d_tensor/layout_converter.py
colossalai/tensor/d_tensor/layout_converter.py
+4
-1
tests/test_tensor/test_dtensor/test_layout_converter.py
tests/test_tensor/test_dtensor/test_layout_converter.py
+9
-2
No files found.
colossalai/tensor/d_tensor/layout_converter.py
View file @
9a3321e9
...
@@ -440,7 +440,10 @@ class LayoutConverter(metaclass=SingletonMeta):
...
@@ -440,7 +440,10 @@ class LayoutConverter(metaclass=SingletonMeta):
total_steps
=
0
total_steps
=
0
transform_path
=
[]
transform_path
=
[]
comm_action_sequence
:
List
[
CommSpec
]
=
[]
comm_action_sequence
:
List
[
CommSpec
]
=
[]
spec_pairs
=
(
str
(
source_spec
.
sharding_sequence
),
str
(
target_spec
.
sharding_sequence
))
src_shape
=
source_layout
.
get_sharded_shape_per_device
()
dst_shape
=
target_layout
.
get_sharded_shape_per_device
()
spec_pairs
=
((
str
(
source_spec
.
sharding_sequence
),
src_shape
),
(
str
(
target_spec
.
sharding_sequence
),
dst_shape
))
if
spec_pairs
in
self
.
cached_solution
:
if
spec_pairs
in
self
.
cached_solution
:
# Solution Cache hit
# Solution Cache hit
...
...
tests/test_tensor/test_dtensor/test_layout_converter.py
View file @
9a3321e9
...
@@ -123,8 +123,15 @@ def check_layout_converting(rank, world_size, port):
...
@@ -123,8 +123,15 @@ def check_layout_converting(rank, world_size, port):
assert
comm_action_sequence
[
2
].
logical_process_axis
==
1
assert
comm_action_sequence
[
2
].
logical_process_axis
==
1
# check cached_spec_pairs_transform_path
# check cached_spec_pairs_transform_path
assert
layout_converter
.
cached_solution
[(
"[R, S01, R]"
,
"[S01, R, R]"
)][
0
]
==
transform_path
src_shape
=
source_layout
.
get_sharded_shape_per_device
()
assert
layout_converter
.
cached_solution
[(
"[R, S01, R]"
,
"[S01, R, R]"
)][
1
]
==
comm_action_sequence
dst_shape
=
target_layout
.
get_sharded_shape_per_device
()
assert
(
layout_converter
.
cached_solution
[((
"[R, S01, R]"
,
src_shape
),
(
"[S01, R, R]"
,
dst_shape
))][
0
]
==
transform_path
)
assert
(
layout_converter
.
cached_solution
[((
"[R, S01, R]"
,
src_shape
),
(
"[S01, R, R]"
,
dst_shape
))][
1
]
==
comm_action_sequence
)
comm_cost
=
layout_converter
.
get_total_comm_cost
(
source_layout
,
target_layout
)
comm_cost
=
layout_converter
.
get_total_comm_cost
(
source_layout
,
target_layout
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment