Unverified Commit 5a52e21f authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[test] fixed the activation codegen test (#1447)

* [test] fixed the activation codegen test

* polish code
parent 0f304236
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
from pprint import pp
import random
import socket
from pathlib import Path
......
from operator import mod
import torch
import pytest
import torch.multiprocessing as mp
from torch.utils.checkpoint import checkpoint
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
......@@ -42,10 +43,9 @@ class MyModule(torch.nn.Module):
return y1 + y2 + y3 + y4
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
def _run_act_ckpt_codegen(rank):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and run forward
model = MyModule()
......@@ -90,10 +90,14 @@ def test_act_ckpt_codegen():
gpc.destroy()
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_act_ckpt_python_code_torch11():
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
mp.spawn(_run_act_ckpt_codegen, nprocs=1)
def _run_act_ckpt_python_code_torch11(rank):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and run forward
model = MyModule()
......@@ -138,6 +142,11 @@ def test_act_ckpt_python_code_torch11():
gpc.destroy()
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_act_ckpt_python_code_torch11():
mp.spawn(_run_act_ckpt_python_code_torch11, nprocs=1)
if __name__ == '__main__':
test_act_ckpt_codegen()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment