Unverified Commit c36a7eee authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Typing] Fallback from Python 3.10+ type syntax for compatibility (#848)

parent 6efeb743
......@@ -5,7 +5,7 @@
import tilelang.language as T
from tvm import ir
from tvm.tir import PrimExpr, Buffer, BufferLoad, BufferRegion, Var, op
from typing import List, Union
from typing import List, Union, Optional
_MEMORY_ORDER_ID_MAP = {
"relaxed": 0,
......@@ -104,7 +104,7 @@ def buffer_region_to_tile_region(buffer_region: BufferRegion, access_type: str,
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
def atomic_max(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None) -> PrimExpr:
"""
Perform an atomic maximum on the value stored at dst with an optional memory-order.
......@@ -113,7 +113,7 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
Parameters:
dst (Buffer): Destination buffer/address to apply the atomic max.
value (PrimExpr): Value to compare/store atomically.
memory_order (str | None): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
memory_order (Optional[str]): Optional memory-order name (e.g. "relaxed", "acquire", "seq_cst").
If provided, it is translated to the corresponding numeric memory-order id before the call.
Returns:
......@@ -126,7 +126,7 @@ def atomic_max(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
_MEMORY_ORDER_ID_MAP[memory_order])
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
def atomic_min(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None) -> PrimExpr:
"""
Atomically update the value at dst to the minimum of its current value and value.
......@@ -135,7 +135,7 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
to integer IDs). If memory_order is None, the extern is invoked without an explicit memory-order argument.
Parameters:
memory_order (str | None): Optional memory-order name controlling the atomic operation's ordering.
memory_order (Optional[str]): Optional memory-order name controlling the atomic operation's ordering.
Returns:
PrimExpr: A handle expression representing the atomic-min operation.
......@@ -147,7 +147,7 @@ def atomic_min(dst: Buffer, value: PrimExpr, memory_order: str | None = None) ->
_MEMORY_ORDER_ID_MAP[memory_order])
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: str | None = None) -> PrimExpr:
def atomic_add(dst: Buffer, value: PrimExpr, memory_order: Optional[str] = None) -> PrimExpr:
"""
Atomically add `value` into `dst`, returning a handle to the operation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment