Unverified Commit 39359e5b authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

Fix FX tracing issues for Llama (#30619)

parent 9719202d
......@@ -714,9 +714,14 @@ class HFCacheProxy(HFProxy):
Proxy that represents an instance of `transformers.cache_utils.Cache`.
"""
def install_orig_cache_cls(self, orig_cache_cls: Type[Cache]):
self._orig_cache_cls = orig_cache_cls
@property
def __class__(self):
return ProxyableCache
if not hasattr(self, "_orig_cache_cls"):
raise RuntimeError("The original Cache class must be installed to the HFCacheProxy.")
return self.tracer._CLASSES_TO_PATCH[self._orig_cache_cls]
def create_wrapper(
......@@ -806,23 +811,39 @@ def _proxies_to_metas(v):
return v
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
global _CURRENT_TRACER
if not isinstance(_CURRENT_TRACER, HFTracer):
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
return HFCacheProxy(n, _CURRENT_TRACER)
def create_cache_proxy_factory_fn(orig_cache_cls: Type[Cache]) -> Callable[[Node], HFCacheProxy]:
def cache_proxy_factory_fn(n: Node) -> HFCacheProxy:
global _CURRENT_TRACER
if not isinstance(_CURRENT_TRACER, HFTracer):
raise RuntimeError("Cannot create HFCacheProxy because there is no HFTracer currently tracing.")
cache_proxy = HFCacheProxy(n, _CURRENT_TRACER)
cache_proxy.install_orig_cache_cls(orig_cache_cls)
return cache_proxy
return cache_proxy_factory_fn
# Proxyable equivalent of the cache classes defined in `transformers.cache_utils`.
ProxyableCache = HFProxyableClassMeta("ProxyableCache", (Cache,), {}, proxy_factory_fn=cache_proxy_factory_fn)
ProxyableCache = HFProxyableClassMeta(
"ProxyableCache", (Cache,), {}, proxy_factory_fn=create_cache_proxy_factory_fn(Cache)
)
ProxyableDynamicCache = HFProxyableClassMeta(
"ProxyableDynamicCache", (DynamicCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableDynamicCache",
(DynamicCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(DynamicCache),
)
ProxyableSinkCache = HFProxyableClassMeta(
"ProxyableSinkCache", (SinkCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableSinkCache",
(SinkCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(SinkCache),
)
ProxyableStaticCache = HFProxyableClassMeta(
"ProxyableStaticCache", (StaticCache,), {}, proxy_factory_fn=cache_proxy_factory_fn
"ProxyableStaticCache",
(StaticCache,),
{},
proxy_factory_fn=create_cache_proxy_factory_fn(StaticCache),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment