"tests/L1/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "40555b3a5e5ef8faefed69cb599fad73afdb9574"
Commit 990e0aa5 authored by Wenhao Xie's avatar Wenhao Xie Committed by LeiWang1999
Browse files

[Refactor] Align torch_assert_close tensor comparison with torch.testing.assert_close (#239)

* [Typo] Fix formatting in installation instructions in README.md

* [Enhancement] Improve CUDA path detection and update configuration handling

* fix typo

* remove IS_WINDOWS constant

* lint fix

* Improve error messages for CUDA detection failure

* lint fix

* lint fix

* Fix .gitignore to correctly include venv directory

* [Doc] Add instructions for installing nightly version of TileLang

* update installation instructions

* update install instruction

* [Refactor] Enhance tensor comparison by integrating attribute checks and equalization in torch_assert_close

* add mismatch info

* set the default value of check_dtype to false

* fix import

* set equal_nan to true by default in torch_assert_close
parent 2a286ae6
...@@ -8,6 +8,8 @@ import torch ...@@ -8,6 +8,8 @@ import torch
import numpy as np import numpy as np
from tvm.testing.utils import * from tvm.testing.utils import *
from tilelang.utils.tensor import torch_assert_close as torch_assert_close
# pytest.main() wrapper to allow running single test file # pytest.main() wrapper to allow running single test file
def main(): def main():
...@@ -15,91 +17,6 @@ def main(): ...@@ -15,91 +17,6 @@ def main():
sys.exit(pytest.main([test_file] + sys.argv[1:])) sys.exit(pytest.main([test_file] + sys.argv[1:]))
def torch_assert_close(tensor_a,
                       tensor_b,
                       rtol=1e-2,
                       atol=1e-2,
                       max_mismatched_ratio=0.001,
                       verbose=False):
    """
    Assert that two tensors are "close enough," allowing a specified
    percentage of mismatched elements.

    Parameters
    ----------
    tensor_a : torch.Tensor
        The first tensor to compare.
    tensor_b : torch.Tensor
        The second tensor to compare.
    rtol : float, optional
        Relative tolerance for comparison. Default is 1e-2.
    atol : float, optional
        Absolute tolerance for comparison. Default is 1e-2.
    max_mismatched_ratio : float, optional
        Maximum ratio of mismatched elements allowed (relative to the total
        number of elements). Default is 0.001 (0.1% of total elements).
    verbose : bool, optional
        If True, print the mismatch statistics even when the check passes.

    Returns
    -------
    bool
        True when the mismatch ratio is within the allowed threshold.

    Raises
    ------
    AssertionError
        If the shapes differ, or the ratio of mismatched elements exceeds
        `max_mismatched_ratio`.
    """
    # Shapes must agree before any element-wise comparison.
    assert tensor_a.shape == tensor_b.shape, f"Tensor shapes must be the same, but got {tensor_a.shape} and {tensor_b.shape}"

    # Element-wise absolute difference and the per-element allowed difference
    # (same tolerance formula as torch.isclose: atol + rtol * |b|).
    diff = torch.abs(tensor_a - tensor_b)
    max_diff = atol + rtol * torch.abs(tensor_b)
    mismatched = diff > max_diff

    # Count mismatches and translate the allowed ratio into an element count.
    num_mismatched = mismatched.sum().item()
    total_elements = tensor_a.numel()
    max_allowed_mismatched = int(total_elements * max_mismatched_ratio)

    if verbose:
        print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
              f"(allowed: {max_allowed_mismatched})")

    if num_mismatched > 0:
        # Locate the first mismatching element for a helpful error message.
        # Use reshape(-1) rather than view(-1): view raises on
        # non-contiguous tensors, reshape copies only when necessary.
        flat_idx = torch.argmax(mismatched.reshape(-1).int()).item()
        idx = [int(i) for i in np.unravel_index(flat_idx, tensor_a.shape)]
        a_val = tensor_a.reshape(-1)[flat_idx].item()
        b_val = tensor_b.reshape(-1)[flat_idx].item()
        abs_diff = abs(a_val - b_val)
        rel_diff = abs_diff / (abs(b_val) + 1e-12)
        mismatch_info = (f"\nFirst mismatch at index {idx}: "
                         f"lhs={a_val:.6f}, rhs={b_val:.6f}, "
                         f"abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}")
    else:
        mismatch_info = ""

    if num_mismatched > max_allowed_mismatched:
        raise AssertionError(
            f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} "
            f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%)."
            f"{mismatch_info}"
            f"\nGreatest absolute difference: {diff.max().item()}, "
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
    return True
def set_random_seed(seed: int = 42) -> None: def set_random_seed(seed: int = 42) -> None:
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
......
...@@ -3,6 +3,7 @@ from enum import Enum ...@@ -3,6 +3,7 @@ from enum import Enum
import torch import torch
from tvm.runtime import ndarray from tvm.runtime import ndarray
from torch.utils.dlpack import to_dlpack from torch.utils.dlpack import to_dlpack
import numpy as np
class TensorSupplyType(Enum): class TensorSupplyType(Enum):
...@@ -102,14 +103,105 @@ def get_tensor_supply(supply_type: TensorSupplyType): ...@@ -102,14 +103,105 @@ def get_tensor_supply(supply_type: TensorSupplyType):
return get_tensor return get_tensor
# TODO: Align with torch.testing.assert_close # Adapted from https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py
def _compare_attributes(
actual: torch.Tensor,
expected: torch.Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
) -> None:
"""Checks if the attributes of two tensors match.
Always checks
- the :attr:`~torch.Tensor.shape`,
- whether both inputs are quantized or not,
- and if they use the same quantization scheme.
Checks for
- :attr:`~torch.Tensor.layout`,
- :meth:`~torch.Tensor.stride`,
- :attr:`~torch.Tensor.device`, and
- :attr:`~torch.Tensor.dtype`
are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair.
"""
def raise_mismatch_error(attribute_name: str, actual_value, expected_value):
raise AssertionError(
f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}."
)
if actual.shape != expected.shape:
raise_mismatch_error("shape", actual.shape, expected.shape)
if actual.is_quantized != expected.is_quantized:
raise_mismatch_error("is_quantized", actual.is_quantized, expected.is_quantized)
elif actual.is_quantized and actual.qscheme() != expected.qscheme():
raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme())
if actual.layout != expected.layout:
if check_layout:
raise_mismatch_error("layout", actual.layout, expected.layout)
elif (actual.layout == torch.strided and check_stride and actual.stride() != expected.stride()):
raise_mismatch_error("stride()", actual.stride(), expected.stride())
if check_device and actual.device != expected.device:
raise_mismatch_error("device", actual.device, expected.device)
if check_dtype and actual.dtype != expected.dtype:
raise_mismatch_error("dtype", actual.dtype, expected.dtype)
def _equalize_attributes(actual: torch.Tensor,
expected: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Equalizes some attributes of two tensors for value comparison.
If ``actual`` and ``expected`` are ...
- ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory.
- ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
:func:`torch.promote_types`).
- ... not of the same ``layout``, they are converted to strided tensors.
Args:
actual (Tensor): Actual tensor.
expected (Tensor): Expected tensor.
Returns:
(Tuple[Tensor, Tensor]): Equalized tensors.
"""
# The comparison logic uses operators currently not supported by the MPS backends.
# See https://github.com/pytorch/pytorch/issues/77144 for details.
# TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend
if actual.is_mps or expected.is_mps: # type: ignore[attr-defined]
actual = actual.cpu()
expected = expected.cpu()
if actual.device != expected.device:
actual = actual.cpu()
expected = expected.cpu()
if actual.dtype != expected.dtype:
actual_dtype = actual.dtype
expected_dtype = expected.dtype
# For uint64, this is not sound in general, which is why promote_types doesn't
# allow it, but for easy testing, we're unlikely to get confused
# by large uint64 overflowing into negative int64
if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]:
actual_dtype = torch.int64
if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]:
expected_dtype = torch.int64
dtype = torch.promote_types(actual_dtype, expected_dtype)
actual = actual.to(dtype)
expected = expected.to(dtype)
if actual.layout != expected.layout:
# These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
actual = actual.to_dense() if actual.layout != torch.strided else actual
expected = (expected.to_dense() if expected.layout != torch.strided else expected)
return actual, expected
def torch_assert_close( def torch_assert_close(
tensor_a, tensor_a,
tensor_b, tensor_b,
rtol=1e-2, rtol=1e-2,
atol=1e-3, atol=1e-3,
max_mismatched_ratio=0.001, max_mismatched_ratio=0.001,
verbose=False, verbose: bool = False,
equal_nan: bool = True,
check_device: bool = True,
check_dtype: bool = False,
check_layout: bool = True,
check_stride: bool = False,
): ):
""" """
Custom function to assert that two tensors are "close enough," allowing a specified Custom function to assert that two tensors are "close enough," allowing a specified
...@@ -134,17 +226,19 @@ def torch_assert_close( ...@@ -134,17 +226,19 @@ def torch_assert_close(
AssertionError: AssertionError:
If the ratio of mismatched elements exceeds `max_mismatched_ratio`. If the ratio of mismatched elements exceeds `max_mismatched_ratio`.
""" """
import torch
_compare_attributes(
tensor_a,
tensor_b,
check_device=check_device,
check_dtype=check_dtype,
check_layout=check_layout,
check_stride=check_stride)
tensor_a, tensor_b = _equalize_attributes(tensor_a, tensor_b)
mismatched = ~torch.isclose(tensor_a, tensor_b, rtol=rtol, atol=atol, equal_nan=equal_nan)
# Compute the absolute difference between the two tensors # Compute the absolute difference between the two tensors
diff = torch.abs(tensor_a - tensor_b) diff = torch.abs(tensor_a - tensor_b)
# Compute the maximum allowable difference for each element
max_diff = atol + rtol * torch.abs(tensor_b)
# Identify elements where the difference exceeds the maximum allowable difference
mismatched = diff > max_diff
# Count the number of mismatched elements # Count the number of mismatched elements
num_mismatched = mismatched.sum().item() num_mismatched = mismatched.sum().item()
...@@ -159,12 +253,29 @@ def torch_assert_close( ...@@ -159,12 +253,29 @@ def torch_assert_close(
print(f"Number of mismatched elements: {num_mismatched} / {total_elements} " print(f"Number of mismatched elements: {num_mismatched} / {total_elements} "
f"(allowed: {max_allowed_mismatched})") f"(allowed: {max_allowed_mismatched})")
# Check if the number of mismatched elements exceeds the allowed threshold # If there are mismatched elements, print the first mismatch
if num_mismatched > 0:
# Find the first mismatch index
flat_idx = torch.argmax(mismatched.view(-1).int()).item()
idx = np.unravel_index(flat_idx, tensor_a.shape)
idx = [int(i) for i in idx]
a_val = tensor_a.reshape(-1)[flat_idx].item()
b_val = tensor_b.reshape(-1)[flat_idx].item()
abs_diff = abs(a_val - b_val)
rel_diff = abs_diff / (abs(b_val) + 1e-12)
mismatch_info = (f"\nFirst mismatch at index {idx}: "
f"lhs={a_val:.6f}, rhs={b_val:.6f}, "
f"abs_diff={abs_diff:.6f}, rel_diff={rel_diff:.6f}")
else:
mismatch_info = ""
# Modify the exception information
if num_mismatched > max_allowed_mismatched: if num_mismatched > max_allowed_mismatched:
raise AssertionError( raise AssertionError(
f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} " f"Too many mismatched elements: {num_mismatched} > {max_allowed_mismatched} "
f"({max_mismatched_ratio * 100:.2f}% allowed). " f"({max_mismatched_ratio * 100:.2f}% allowed, but get {num_mismatched / total_elements * 100:.2f}%)."
f"Greatest absolute difference: {diff.max().item()}, " f"{mismatch_info}"
f"\nGreatest absolute difference: {diff.max().item()}, "
f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.") f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
else: else:
return True return True
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment