Unverified Commit 37d8611a authored by statelesshz's avatar statelesshz Committed by GitHub
Browse files

replace no_cuda with use_cpu in test_pytorch_examples (#24944)

* replace no_cuda with use_cpu in test_pytorch_examples

* remove codes that never be used

* fix style
parent 79444f37
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import argparse
import json import json
import logging import logging
import os import os
...@@ -76,13 +75,6 @@ logging.basicConfig(level=logging.DEBUG) ...@@ -76,13 +75,6 @@ logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger() logger = logging.getLogger()
def get_setup_file():
parser = argparse.ArgumentParser()
parser.add_argument("-f")
args = parser.parse_args()
return args.f
def get_results(output_dir): def get_results(output_dir):
results = {} results = {}
path = os.path.join(output_dir, "all_results.json") path = os.path.join(output_dir, "all_results.json")
...@@ -153,8 +145,8 @@ class ExamplesTests(TestCasePlus): ...@@ -153,8 +145,8 @@ class ExamplesTests(TestCasePlus):
# Skipping because there are not enough batches to train the model + would need a drop_last to work. # Skipping because there are not enough batches to train the model + would need a drop_last to work.
return return
if torch_device != "cuda": if torch_device == "cpu":
testargs.append("--no_cuda") testargs.append("--use_cpu")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_clm.main() run_clm.main()
...@@ -175,8 +167,8 @@ class ExamplesTests(TestCasePlus): ...@@ -175,8 +167,8 @@ class ExamplesTests(TestCasePlus):
--config_overrides n_embd=10,n_head=2 --config_overrides n_embd=10,n_head=2
""".split() """.split()
if torch_device != "cuda": if torch_device == "cpu":
testargs.append("--no_cuda") testargs.append("--use_cpu")
logger = run_clm.logger logger = run_clm.logger
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -201,8 +193,8 @@ class ExamplesTests(TestCasePlus): ...@@ -201,8 +193,8 @@ class ExamplesTests(TestCasePlus):
--num_train_epochs=1 --num_train_epochs=1
""".split() """.split()
if torch_device != "cuda": if torch_device == "cpu":
testargs.append("--no_cuda") testargs.append("--use_cpu")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_mlm.main() run_mlm.main()
...@@ -231,8 +223,8 @@ class ExamplesTests(TestCasePlus): ...@@ -231,8 +223,8 @@ class ExamplesTests(TestCasePlus):
--seed 7 --seed 7
""".split() """.split()
if torch_device != "cuda": if torch_device == "cpu":
testargs.append("--no_cuda") testargs.append("--use_cpu")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
run_ner.main() run_ner.main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment