"tests/vscode:/vscode.git/clone" did not exist on "d86ddd9b2910ef0e9a093039d70c3789d3af3517"
Unverified Commit 4a09fc09 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[fx] fixed tracing with apex-based T5 model (#1252)

* [fx] fixed tracing with apex-based T5 model

* polish code

* polish code
parent 7531c627
......@@ -7,6 +7,7 @@ tracer.py:
import enum
import inspect
import functools
from colossalai.fx.tracer.meta_patch import meta_patched_module
import torch
import torch.nn as nn
from torch import Tensor
......@@ -181,7 +182,16 @@ class ColoTracer(Tracer):
def call_module(self, m, forward, args, kwargs):
self.orig_forward = forward
return super().call_module(m, forward, args, kwargs)
module_qualified_name = self.path_of_module(m)
# a leaf module is the torch.nn.Module subclasses starting with `torch.nn`
# which means customized modules are not leaf module by default
# if a customized or third-party module like apex.normalization.FusedRMSNorm is patched,
# we should treat it as leaf module as well
if meta_patched_module.has(m.__class__) or self.is_leaf_module(m, module_qualified_name):
return self.create_proxy('call_module', module_qualified_name, args, kwargs)
else:
return forward(*args, **kwargs)
def proxy(self, node) -> Proxy:
"""
......
import pytest
import transformers
import torch
from colossalai.fx.tracer.meta_patch import meta_patched_module
from utils import trace_model_and_compare_output
try:
import apex
@meta_patched_module.register(apex.normalization.FusedRMSNorm)
def apex_fused_layernorm(self, input):
return torch.empty(input.shape, device='meta')
except ImportError:
pass
BATCH_SIZE = 1
SEQ_LENGHT = 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment