Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0387a47e
Unverified
Commit
0387a47e
authored
Aug 29, 2023
by
Baizhou Zhang
Committed by
GitHub
Aug 29, 2023
Browse files
[shardformer] fix emerged bugs after updating transformers (#4526)
parent
c554b7f5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
2 deletions
+9
-2
colossalai/pipeline/schedule/_utils.py
colossalai/pipeline/schedule/_utils.py
+4
-1
tests/test_shardformer/test_model/_utils.py
tests/test_shardformer/test_model/_utils.py
+5
-1
No files found.
colossalai/pipeline/schedule/_utils.py
View file @
0387a47e
...
@@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any:
...
@@ -123,7 +123,10 @@ def merge_batch(data: List[Any]) -> Any:
merged_data
=
[]
merged_data
=
[]
for
elem_batch
in
zip
(
*
flattened_data
):
for
elem_batch
in
zip
(
*
flattened_data
):
if
isinstance
(
elem_batch
[
0
],
torch
.
Tensor
):
if
isinstance
(
elem_batch
[
0
],
torch
.
Tensor
):
merged_data
.
append
(
torch
.
cat
(
elem_batch
,
dim
=
0
))
if
len
(
elem_batch
[
0
].
shape
)
==
0
:
# set loss to None in pipeline outputs
merged_data
.
append
(
None
)
else
:
merged_data
.
append
(
torch
.
cat
(
elem_batch
,
dim
=
0
))
else
:
else
:
merged_data
.
append
(
list
(
elem_batch
))
merged_data
.
append
(
list
(
elem_batch
))
return
tree_unflatten
(
merged_data
,
tree_spec
)
return
tree_unflatten
(
merged_data
,
tree_spec
)
tests/test_shardformer/test_model/_utils.py
View file @
0387a47e
...
@@ -195,7 +195,11 @@ def check_output_hidden_state(org_output: Tensor,
...
@@ -195,7 +195,11 @@ def check_output_hidden_state(org_output: Tensor,
sharded_hidden_state
=
sharded_output
.
last_hidden_state
sharded_hidden_state
=
sharded_output
.
last_hidden_state
if
stage_manager
and
stage_manager
.
is_last_stage
():
if
stage_manager
and
stage_manager
.
is_last_stage
():
sharded_hidden_state
=
torch
.
cat
([
output
.
last_hidden_state
for
output
in
sharded_output
[
'outputs'
]],
dim
=
dim
)
pipeline_output
=
sharded_output
[
'outputs'
]
if
isinstance
(
pipeline_output
,
List
):
sharded_hidden_state
=
torch
.
cat
([
output
.
last_hidden_state
for
output
in
pipeline_output
],
dim
=
dim
)
else
:
sharded_hidden_state
=
pipeline_output
.
last_hidden_state
assert
torch
.
allclose
(
org_hidden_state
.
float
(),
sharded_hidden_state
.
float
(),
atol
=
atol
,
rtol
=
rtol
),
\
assert
torch
.
allclose
(
org_hidden_state
.
float
(),
sharded_hidden_state
.
float
(),
atol
=
atol
,
rtol
=
rtol
),
\
f
"shard model's output hidden state is not equal to origin model's last hidden state
\n
{
org_hidden_state
}
\n
{
sharded_hidden_state
}
"
f
"shard model's output hidden state is not equal to origin model's last hidden state
\n
{
org_hidden_state
}
\n
{
sharded_hidden_state
}
"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment