OpenDAS / ColossalAI · Commits

Commit 7c079d9c (unverified)
Authored Mar 11, 2022 by HELSON; committed by GitHub on Mar 11, 2022
Parent: 1e4bf85c

[hotfix] fixed bugs in ShardStrategy and PcieProfiler (#394)
Showing 2 changed files with 4 additions and 3 deletions:

  colossalai/utils/profiler/pcie_profiler.py            (+1, -1)
  colossalai/zero/shard_utils/tensor_shard_strategy.py  (+3, -2)
colossalai/utils/profiler/pcie_profiler.py

@@ -79,7 +79,7 @@ class PcieProfiler(BaseProfiler):
         if self.profiler.enabled:
             events = self.profiler.function_events
             for event in events:
-                if event.name == "aten::_to_copy":
+                if event.name == "aten::copy_":
                     t_shape = event.input_shapes[0]
                     if len(t_shape) == 0 or event.cuda_time_total == 0:
                         continue
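The one-line change above swaps the profiler event name that PcieProfiler filters on from "aten::_to_copy" to "aten::copy_": with the legacy torch.autograd profiler, the device-side copy work is recorded on the "aten::copy_" function event, so matching the old name made the profiler skip the transfers it was meant to measure. A minimal sketch of the event filtering this hunk relies on (not from the commit; assumes the legacy torch.autograd.profiler API and a CUDA-capable machine):

    import torch
    from torch.autograd import profiler

    x = torch.randn(1024, 1024)

    # record_shapes=True is required for event.input_shapes, which PcieProfiler reads
    with profiler.profile(use_cuda=True, record_shapes=True) as prof:
        y = x.cuda()  # host-to-device transfer

    for event in prof.function_events:
        if event.name == "aten::copy_":  # the name the hotfix switches to
            print(event.name, event.input_shapes[0], event.cuda_time_total)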
colossalai/zero/shard_utils/tensor_shard_strategy.py

@@ -5,6 +5,7 @@ import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils import get_current_device


 class TensorShardStrategy(BaseShardStrategy):

@@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
             if i == self.local_rank:
-                buffer_list.append(t.payload.cuda())
+                buffer_list.append(t.payload.cuda(get_current_device()))
             else:
-                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
+                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))

         torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank],
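Both hunks fix device placement during the gather. A bare .cuda() moves a tensor to the process's current CUDA device, which defaults to cuda:0 unless torch.cuda.set_device has been called, so in multi-process training every rank could end up allocating its gather buffers on GPU 0. Passing get_current_device() pins both the local payload and the receive buffers to the rank's own device; using device=... in torch.zeros also avoids allocating on the CPU and copying afterwards. A standalone sketch of the corrected pattern (hypothetical gather_payload helper; assumes an initialized process group, one process per GPU with the rank used as the device index in place of ColossalAI's get_current_device(), and a flat 1-D payload as ShardedTensor stores it):

    import torch
    import torch.distributed as dist

    def gather_payload(payload: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
        # Stand-in for colossalai.utils.get_current_device(); assumes rank == GPU index.
        device = torch.device(f"cuda:{rank}")
        buffer_list = []
        for i in range(world_size):
            if i == rank:
                buffer_list.append(payload.cuda(device))  # move the local shard explicitly
            else:
                # Allocate the receive buffer directly on this rank's device; a bare
                # .cuda() would target the current device, which may be cuda:0 everywhere.
                buffer_list.append(torch.zeros(payload.numel(), dtype=payload.dtype, device=device))
        dist.all_gather(buffer_list, buffer_list[rank])  # gathers every rank's shard in place
        return torch.cat(buffer_list)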