OpenDAS / ColossalAI · Commits · bd186784

Unverified commit bd186784, authored Sep 05, 2023 by Hongxin Liu, committed by GitHub on Sep 05, 2023.

[test] fix gemini checkpoint and gpt test (#4620)

Parent: e71d2452
Changes: 2 changed files, with 2 additions and 3 deletions (+2 -3).

- tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py (+1 -1)
- tests/test_shardformer/test_model/test_shard_gpt2.py (+1 -2)
tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py (+1 -1)

```diff
@@ -32,7 +32,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
     elif plugin_type == 'zero':
         plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
     elif plugin_type == 'gemini':
-        plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
+        plugin = GeminiPlugin(precision="fp16", initial_scale=32)
     else:
         raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
```
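The change drops the explicit `placement_policy='cuda'` argument when the test builds its `GeminiPlugin`, leaving only `precision` and `initial_scale`. The plugin-selection branching can be sketched in isolation; the dataclasses below are hypothetical stand-ins for the real ColossalAI plugin classes (kept only so the sketch runs without ColossalAI installed), not their actual signatures:

```python
# Hypothetical stand-ins for the ColossalAI plugin classes named in the diff.
# Only the keyword arguments that appear in the test are modeled here.
from dataclasses import dataclass


@dataclass
class LowLevelZeroPlugin:
    stage: int = 2
    max_norm: float = 1.0
    initial_scale: int = 32


@dataclass
class GeminiPlugin:
    # After this commit the test no longer passes placement_policy='cuda'.
    precision: str = "fp16"
    initial_scale: int = 32


def make_plugin(plugin_type: str):
    """Mirror the branching in the test's exam_from_pretrained()."""
    if plugin_type == 'zero':
        return LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
    elif plugin_type == 'gemini':
        return GeminiPlugin(precision="fp16", initial_scale=32)
    else:
        raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
```

An unknown plugin type falls through to the `ValueError`, mirroring the test's final `else` branch.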
tests/test_shardformer/test_model/test_shard_gpt2.py (+1 -2)

```diff
@@ -102,7 +102,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()
 
-@pytest.mark.skip(reason="This test will hang in CI")
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
@@ -220,7 +219,7 @@ def check_gpt2_3d(rank, world_size, port):
     run_gpt2_3d_test()
 
+@pytest.mark.skip(reason="This test will hang in CI")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
```
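The diff moves the skip marker: the parameterized forward/backward test is re-enabled, while the 3D-parallel test entry point gets `@pytest.mark.skip(reason="This test will hang in CI")`. The same skip-with-reason idea can be sketched with only the standard library's unittest (a stand-in for pytest here, so the sketch runs anywhere; the test case body is illustrative, not the real distributed test):

```python
# Minimal sketch of skipping a test that would hang in CI, using stdlib
# unittest instead of pytest.mark.skip. The class body never executes.
import unittest


@unittest.skip("This test will hang in CI")
class TestGPT23D(unittest.TestCase):
    def test_gpt2_3d(self):
        # Would launch the distributed 3D-parallel GPT-2 run; skipped instead.
        raise RuntimeError("should never run")


result = unittest.TestResult()
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestGPT23D)
suite.run(result)
# result.skipped now holds one (test, reason) pair and no errors are recorded.
```

Like `pytest.mark.skip`, the reason string is recorded in the test report rather than discarded, so CI logs still show why the test did not run.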