"...git@developer.sourcefind.cn:OpenDAS/diffusers.git" did not exist on "01782c220eea24f74e149c0ab1153e74ba65711b"
Unverified Commit 3516f1ee authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Language] Enhance T.dtype.as_torch conversion for compatibility (#1473)

* [Language] Enhance dtype conversion for PyTorch compatibility

- Added support for new float8 and float4 data types in the __dtype_as_torch__ method.
- Implemented backend-specific handling for float8_e4m3 based on HIP or CUDA.
- Included assertions to ensure compatibility with the required PyTorch versions for each dtype.
- Improved error handling for unsupported dtypes.

* Fix test script execution and improve error messages for dtype assertions

- Commented out the main execution call in the test script and replaced it with a direct call to the test function `test_divmod()`.
- Enhanced error messages in the dtype conversion assertions to improve clarity and readability, ensuring proper guidance for required PyTorch versions.
parent 95e3b5a7
...@@ -157,8 +157,42 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var ...@@ -157,8 +157,42 @@ def __dtype_call__(self: dtype, expr=None, is_size_var: bool = False) -> tir.Var
def __dtype_as_torch__(self: dtype) -> torch.dtype: def __dtype_as_torch__(self: dtype) -> torch.dtype:
"""Convert TileLang dtype to PyTorch dtype.""" """Convert TileLang dtype to PyTorch dtype."""
dtype_str = str(self) dtype_str = str(self)
if dtype_str in _STR_TO_TORCH_DTYPE:
if dtype_str == "float8_e4m3":
# Check if we're on HIP (AMD ROCm) or CUDA
if torch.version.hip is not None:
# HIP backend - use float8_e4m3fnuz
assert hasattr(torch, "float8_e4m3fnuz"), (
"torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0"
)
return torch.float8_e4m3fnuz
else:
# CUDA backend - use float8_e4m3fn
assert hasattr(torch, "float8_e4m3fn"), (
"torch.float8_e4m3fn is not supported in this version of torch. Please upgrade torch >= 2.1.0"
)
return torch.float8_e4m3fn
elif dtype_str == "float8_e5m2":
assert hasattr(torch, "float8_e5m2"), "torch.float8_e5m2 is not supported in this version of torch. Please upgrade torch >= 2.1.0"
return torch.float8_e5m2
elif dtype_str == "e4m3fnuz_float8":
assert hasattr(torch, "float8_e4m3fnuz"), (
"torch.float8_e4m3fnuz is not supported in this version of torch. Please upgrade torch >= 2.2.0"
)
return torch.float8_e4m3fnuz
elif dtype_str == "float8_e8m0fnu":
assert hasattr(torch, "float8_e8m0fnu"), (
"torch.float8_e8m0fnu is not supported in this version of torch. Please upgrade torch >= 2.8.0"
)
return torch.float8_e8m0fnu
elif dtype_str == "float4_e2m1fnx2":
assert hasattr(torch, "float4_e2m1fnx2"), (
"torch.float4_e2m1fnx2 is not supported in this version of torch. Please upgrade torch >= 2.8.0"
)
return torch.float4_e2m1fnx2
elif dtype_str in _STR_TO_TORCH_DTYPE:
return _STR_TO_TORCH_DTYPE[dtype_str] return _STR_TO_TORCH_DTYPE[dtype_str]
raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}") raise ValueError(f"Cannot convert dtype '{dtype_str}' to torch.dtype. Supported dtypes: {list(_STR_TO_TORCH_DTYPE.keys())}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment