Commit 87cddf7e authored by oahzxl's avatar oahzxl
Browse files

rename and remove useless func

parent f5c5d4c4
This diff is collapsed.
...@@ -11,7 +11,7 @@ from colossalai.core import global_context as gpc ...@@ -11,7 +11,7 @@ from colossalai.core import global_context as gpc
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
try: try:
from chunk_codegen import ActivationCheckpointCodeGen from chunk_codegen import ChunkCodeGen
with_codegen = True with_codegen = True
except: except:
# fall back to older pytorch version # fall back to older pytorch version
...@@ -75,7 +75,7 @@ def _run_offload_codegen(rank): ...@@ -75,7 +75,7 @@ def _run_offload_codegen(rank):
# trace the module and replace codegen # trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True) tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model) graph = tracer.trace(model)
codegen = ActivationCheckpointCodeGen() codegen = ChunkCodeGen()
graph.set_codegen(codegen) graph.set_codegen(codegen)
# annotate the activation offload part # annotate the activation offload part
...@@ -99,15 +99,7 @@ def _run_offload_codegen(rank): ...@@ -99,15 +99,7 @@ def _run_offload_codegen(rank):
# assert we have all the components # assert we have all the components
code = graph.python_code("self").src code = graph.python_code("self").src
assert "def pack_hook_input(self, x):" in code and \ print(code)
"def unpack_hook(self, packed):" in code and \
"def pack_hook_no_input(self, x):" in code and \
"setattr(x, 'offload', True)" in code and \
"setattr(linear3, 'offload', False)" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
_test_fwd_and_bwd(model, gm, data) _test_fwd_and_bwd(model, gm, data)
gpc.destroy() gpc.destroy()
...@@ -118,60 +110,5 @@ def test_act_ckpt_codegen(): ...@@ -118,60 +110,5 @@ def test_act_ckpt_codegen():
mp.spawn(_run_offload_codegen, nprocs=1) mp.spawn(_run_offload_codegen, nprocs=1)
def _run_offload_codegen_torch11(rank):
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and input
model = MyNet().cuda()
data = torch.rand(4, 4).cuda()
# trace the module and replace codegen
tracer = ColoTracer(trace_act_ckpt=True)
graph = tracer.trace(model)
# replace a bound method of an object
graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
# annotate the activation offload part
# also annotate the activation_checkpoint so we could test both types
# of input offload
for node in graph.nodes:
if node.name == "linear0":
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear1":
setattr(node, "activation_offload", [0, True, False])
if node.name == "linear2":
setattr(node, "activation_offload", [1, True, True])
if node.name == "linear4":
setattr(node, "activation_offload", [2, False, True])
if node.name == "linear5":
setattr(node, "activation_checkpoint", [0])
setattr(node, "activation_offload", True)
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()
# assert we have all the components
code = graph.python_code("self").src
assert "def pack_hook_input(self, x):" in code and \
"def unpack_hook(self, packed):" in code and \
"def pack_hook_no_input(self, x):" in code and \
"setattr(x, 'offload', True)" in code and \
"setattr(linear3, 'offload', False)" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
"with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
"with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
"colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
_test_fwd_and_bwd(model, gm, data)
gpc.destroy()
@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
def test_act_ckpt_python_code_torch11():
mp.spawn(_run_offload_codegen_torch11, nprocs=1)
if __name__ == "__main__": if __name__ == "__main__":
_run_offload_codegen(0) _run_offload_codegen(0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment