Unverified Commit 8119550b authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix][Language V2] Capture closure variables from program (#1206)



* Enhance CUDA code generation by improving register type handling for float data types and introducing a workaround for TF32 compatibility. Updated MMA register type registration for A and B operands to boost performance and ensure correctness.

* lint fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent c8ec3469
......@@ -575,10 +575,25 @@ def get_type_hints(func):
if annot is None:
raise TypeError(f'Failed to get function type hints, {func} is not a function')
hints = {}
# type params are not used currently, it is support since python 3.12.4
# type_params = getattr(func, "__type_params__", ())
globalns = getattr(func, '__globals__', {})
localns = globalns
# Build eval namespaces from function globals plus captured closure variables
# This lets annotations reference symbols like `n`, `h`, or dtype vars
# defined in the outer scope of a nested function.
globalns = dict(getattr(func, '__globals__', {}))
localns = dict(globalns)
try:
freevars = getattr(func.__code__, 'co_freevars', ())
cells = getattr(func, '__closure__', ()) or ()
closure_bindings = {
name: cell.cell_contents for name, cell in zip(freevars, cells) if name not in localns
}
if closure_bindings:
localns.update(closure_bindings)
# Also update globals so ForwardRef eval sees them uniformly
globalns.update(closure_bindings)
except Exception:
# Be permissive: absence or access issues with closure shouldn't crash
pass
for name, value in annot.items():
if name == 'return':
continue
......@@ -588,10 +603,12 @@ def get_type_hints(func):
if value is None:
value = type(None)
if isinstance(value, str):
# this branch handles T.float32 style annotation
# since they are string, directly evaluating them usually causes NameError
# so we need to split and evaluate them separately
# Handle simple dtype aliases like T.float32 appearing as strings
# Evaluate directly only when it matches known dtypes
try:
_, v = value.split('.', maxsplit=1)
except ValueError:
v = value
if v in dt._all_dtypes:
try:
hints[name] = eval(value, globalns, localns)
......@@ -599,8 +616,7 @@ def get_type_hints(func):
except Exception:
pass
value = ForwardRef(value, is_argument=True, is_class=False)
hints[name] = _eval_type(
value, globalns=globalns, localns=localns) #, type_params=type_params)
hints[name] = _eval_type(value, globalns=globalns, localns=localns)
return hints
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment