"python-wheel/examples/bls/bls.py" did not exist on "c9130f8f8ce264379131e9ee2973534fe4cbf713"
Unverified commit 5a52e21f, authored by Frank Lee and committed via GitHub.
Browse files

[test] fixed the activation codegen test (#1447)

* [test] fixed the activation codegen test

* polish code
parent 0f304236
#!/usr/bin/env python #!/usr/bin/env python
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import os import os
from pprint import pp
import random import random
import socket import socket
from pathlib import Path from pathlib import Path
......
from operator import mod from operator import mod
import torch import torch
import pytest import pytest
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from torch.fx import GraphModule from torch.fx import GraphModule
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
...@@ -42,10 +43,9 @@ class MyModule(torch.nn.Module): ...@@ -42,10 +43,9 @@ class MyModule(torch.nn.Module):
return y1 + y2 + y3 + y4 return y1 + y2 + y3 + y4
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0') def _run_act_ckpt_codegen(rank):
def test_act_ckpt_codegen():
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and run forward # build model and run forward
model = MyModule() model = MyModule()
...@@ -90,10 +90,14 @@ def test_act_ckpt_codegen(): ...@@ -90,10 +90,14 @@ def test_act_ckpt_codegen():
gpc.destroy() gpc.destroy()
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
    """Run the activation-checkpoint codegen check in a spawned worker process.

    Only runs when ``with_codegen`` is true (torch >= 1.12.0, per the skip
    reason); older torch versions are covered by the python-code variant below.
    """
    # nprocs=1: the spawned runner launches colossalai with world_size=1.
    mp.spawn(_run_act_ckpt_codegen, nprocs=1)
def _run_act_ckpt_python_code_torch11(rank):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly # launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and run forward # build model and run forward
model = MyModule() model = MyModule()
...@@ -138,6 +142,11 @@ def test_act_ckpt_python_code_torch11(): ...@@ -138,6 +142,11 @@ def test_act_ckpt_python_code_torch11():
gpc.destroy() gpc.destroy()
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_act_ckpt_python_code_torch11():
    """Run the pre-1.12 activation-checkpoint python-code check in a spawned worker.

    Skipped when ``with_codegen`` is true (torch >= 1.12.0 exercises the
    codegen path instead).
    """
    # nprocs=1: the spawned runner launches colossalai with world_size=1.
    mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)
if __name__ == '__main__':
    # Manual entry point for running the codegen test outside of pytest.
    test_act_ckpt_codegen()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment