Unverified Commit 39359e5b authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

Fix FX tracing issues for Llama (#30619)

parent 9719202d
...@@ -714,9 +714,14 @@ class HFCacheProxy(HFProxy): ...@@ -714,9 +714,14 @@ class HFCacheProxy(HFProxy):
Proxy that represents an instance of `transformers.cache_utils.Cache`. Proxy that represents an instance of `transformers.cache_utils.Cache`.
""" """
def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
self._orig_cache_cls = orig_cache_cls
@property @property
def __class__(self): def __class__(self):
return ProxyableCache if not hasattr(self, "_orig_cache_cls"):
raise RuntimeError("The original Cache class must be installed to the HFCacheProxy.")
return self.tracer._CLASSES_TO_PATCH[self._orig_cache_cls]
def create_wrapper( def create_wrapper(
...@@ -806,23 +811,39 @@ def _proxies_to_metas(v): ...@@ -806,23 +811,39 @@ def _proxies_to_metas(v):
return v return v
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy: def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
global _CURRENT_TRACER def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
if not isinstance(_CURRENT_TRACER, HFTracer): global _CURRENT_TRACER
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.") if not isinstance(_CURRENT_TRACER, HFTracer):
return HFCacheProxy(n, _CURRENT_TRACER) raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
cache_proxy.install_orig_cache_cls(orig_cache_cls)
return cache_proxy
return cache_proxy_factory_fn
# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`. # Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn) ProxyableCache = HFProxyableClassMeta(
"ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
)
ProxyableDynamicCache = HFProxyableClassMeta( ProxyableDynamicCache = HFProxyableClassMeta(
"ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn "ProxyableDynamicCache",
(DynamicCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
) )
ProxyableSinkCache = HFProxyableClassMeta( ProxyableSinkCache = HFProxyableClassMeta(
"ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn "ProxyableSinkCache",
(SinkCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
) )
ProxyableStaticCache = HFProxyableClassMeta( ProxyableStaticCache = HFProxyableClassMeta(
"ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn "ProxyableStaticCache",
(StaticCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment