Unverified Commit 8119550b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix][Language V2] Capture closure variables from program (#1206)



* Enhance CUDA code generation by improving register type handling for float data types and introducing a workaround for TF32 compatibility. Updated MMA register type registration for A and B operands to boost performance and ensure correctness.

* lint fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent c8ec3469
...@@ -575,10 +575,25 @@ def get_type_hints(func): ...@@ -575,10 +575,25 @@ def get_type_hints(func):
if annot is None: if annot is None:
raise TypeError(f'Failed to get function type hints, {func} is not a function') raise TypeError(f'Failed to get function type hints, {func} is not a function')
hints = {} hints = {}
# type params are not used currently, it is support since python 3.12.4 # Build eval namespaces from function globals plus captured closure variables
# type_params = getattr(func, "__type_params__", ()) # This lets annotations reference symbols like `n`, `h`, or dtype vars
globalns = getattr(func, '__globals__', {}) # defined in the outer scope of a nested function.
localns = globalns globalns = dict(getattr(func, '__globals__', {}))
localns = dict(globalns)
try:
freevars = getattr(func.__code__, 'co_freevars', ())
cells = getattr(func, '__closure__', ()) or ()
closure_bindings = {
name: cell.cell_contents for name, cell in zip(freevars, cells) if name not in localns
}
if closure_bindings:
localns.update(closure_bindings)
# Also update globals so ForwardRef eval sees them uniformly
globalns.update(closure_bindings)
except Exception:
# Be permissive: absence or access issues with closure shouldn't crash
pass
for name, value in annot.items(): for name, value in annot.items():
if name == 'return': if name == 'return':
continue continue
...@@ -588,10 +603,12 @@ def get_type_hints(func): ...@@ -588,10 +603,12 @@ def get_type_hints(func):
if value is None: if value is None:
value = type(None) value = type(None)
if isinstance(value, str): if isinstance(value, str):
# this branch handles T.float32 style annotation # Handle simple dtype aliases like T.float32 appearing as strings
# since they are string, directly evaluating them usually causes NameError # Evaluate directly only when it matches known dtypes
# so we need to split and evaluate them separately try:
_, v = value.split('.', maxsplit=1) _, v = value.split('.', maxsplit=1)
except ValueError:
v = value
if v in dt._all_dtypes: if v in dt._all_dtypes:
try: try:
hints[name] = eval(value, globalns, localns) hints[name] = eval(value, globalns, localns)
...@@ -599,8 +616,7 @@ def get_type_hints(func): ...@@ -599,8 +616,7 @@ def get_type_hints(func):
except Exception: except Exception:
pass pass
value = ForwardRef(value, is_argument=True, is_class=False) value = ForwardRef(value, is_argument=True, is_class=False)
hints[name] = _eval_type( hints[name] = _eval_type(value, globalns=globalns, localns=localns)
value, globalns=globalns, localns=localns) #, type_params=type_params)
return hints return hints
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment