"examples/vscode:/vscode.git/clone" did not exist on "7713e25c1ab08aec268a8b2156a4a1bd2d3dfd4a"
Unverified Commit 4a09fc09 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[fx] fixed tracing with apex-based T5 model (#1252)

* [fx] fixed tracing with apex-based T5 model

* polish code

* polish code
parent 7531c627
...@@ -7,6 +7,7 @@ tracer.py: ...@@ -7,6 +7,7 @@ tracer.py:
import enum import enum
import inspect import inspect
import functools import functools
from colossalai.fx.tracer.meta_patch import meta_patched_module
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
...@@ -181,7 +182,16 @@ class ColoTracer(Tracer): ...@@ -181,7 +182,16 @@ class ColoTracer(Tracer):
def call_module(self, m, forward, args, kwargs):
    """Trace a call into submodule ``m``, honoring meta-patched third-party modules.

    Stashes the original ``forward`` (so patched forwards can delegate to it),
    then either records a single ``call_module`` node when the module should be
    treated as a leaf, or traces through the module's forward otherwise.
    """
    self.orig_forward = forward
    qualified_name = self.path_of_module(m)
    # By default only ``torch.nn`` builtins count as leaf modules, so customized
    # or third-party modules (e.g. apex.normalization.FusedRMSNorm) would be
    # traced through.  If such a module has been meta-patched, we must treat it
    # as a leaf as well so its patched shape function is used instead.
    treat_as_leaf = meta_patched_module.has(m.__class__) or self.is_leaf_module(m, qualified_name)
    if treat_as_leaf:
        return self.create_proxy('call_module', qualified_name, args, kwargs)
    return forward(*args, **kwargs)
def proxy(self, node) -> Proxy: def proxy(self, node) -> Proxy:
""" """
......
import pytest import pytest
import transformers import transformers
import torch import torch
from colossalai.fx.tracer.meta_patch import meta_patched_module
from utils import trace_model_and_compare_output from utils import trace_model_and_compare_output
try:
    # apex is an optional dependency; the patch below is only registered
    # when it is importable.
    import apex

    # Register a meta-device patch for apex's FusedRMSNorm so the tracer can
    # treat it as a leaf module: on the meta device we only need the output
    # shape, which for a normalization layer equals the input shape.
    # NOTE(review): the function name says "layernorm" but it patches
    # FusedRMSNorm — presumably a naming slip; the registry appears to key on
    # the class, not the function name, so behavior is unaffected. Confirm.
    @meta_patched_module.register(apex.normalization.FusedRMSNorm)
    def apex_fused_layernorm(self, input):
        return torch.empty(input.shape, device='meta')
except ImportError:
    # Without apex the patch is simply skipped; tests relying on it should
    # be guarded accordingly.
    pass
BATCH_SIZE = 1 BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGHT = 16
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment