Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
98cdbf49
Unverified
Commit
98cdbf49
authored
Jun 07, 2022
by
ver217
Committed by
GitHub
Jun 07, 2022
Browse files
[hotfix] fix chunk comm src rank (#1072)
parent
bfdc5ccb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
colossalai/tensor/chunk.py
colossalai/tensor/chunk.py
+3
-2
No files found.
colossalai/tensor/chunk.py
View file @
98cdbf49
...
...
@@ -47,6 +47,7 @@ class Chunk:
self
.
utilized_size
=
0
self
.
src_rank
=
src_rank
self
.
is_src_rank
=
gpc
.
get_local_rank
(
ParallelMode
.
DATA
)
==
src_rank
self
.
global_src_rank
=
gpc
.
get_ranks_in_group
(
ParallelMode
.
DATA
)[
src_rank
]
self
.
dtype
=
dtype
self
.
device
=
init_device
or
get_current_device
()
self
.
data
=
torch
.
empty
(
chunk_size
,
dtype
=
dtype
,
device
=
self
.
device
)
...
...
@@ -87,7 +88,7 @@ class Chunk:
if
not
self
.
is_src_rank
:
self
.
data
.
storage
().
resize_
(
self
.
size
)
self
.
data
.
data
=
self
.
data
.
to
(
get_current_device
())
dist
.
broadcast
(
self
.
data
,
self
.
src_rank
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
dist
.
broadcast
(
self
.
data
,
self
.
global_
src_rank
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
self
.
_update_tensors_ptr
()
if
not
self
.
is_src_rank
:
self
.
_update_tensors_state
(
TensorState
.
HOLD
,
prev_state
=
TensorState
.
FREE
)
...
...
@@ -101,7 +102,7 @@ class Chunk:
if
is_all_reduce
:
dist
.
all_reduce
(
self
.
data
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
else
:
dist
.
reduce
(
self
.
data
,
self
.
src_rank
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
dist
.
reduce
(
self
.
data
,
self
.
global_
src_rank
,
group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
self
.
_update_tensors_ptr
()
self
.
_update_tensors_state
(
TensorState
.
HOLD
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment