Commit 71537ba5 authored by Wenhao Xie, committed by LeiWang1999

[BugFix] Fix bug of mismatching dtype in testing (#245)

* [Typo] Fix formatting in installation instructions in README.md

* [Enhancement] Improve CUDA path detection and update configuration handling

* fix typo

* remove IS_WINDOWS constant

* lint fix

* Improve error messages for CUDA detection failure

* lint fix

* lint fix

* Fix .gitignore to correctly include venv directory

* [Doc] Add instructions for installing nightly version of TileLang

* update installation instructions

* update install instruction

* fix bug of mismatched dtype in testing and set the default value of check_dtype in torch_assert_close to True

* lint fix

* fix bug

* use map_torch_type
parent c264f37f
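The core of the fix is replacing torch.__getattribute__(in_dtype) with map_torch_type from tilelang.utils.tensor, so string dtype names are resolved to torch.dtype objects up front and every tensor in the tests is created with the dtype the kernel actually expects. A minimal sketch of the idea follows; the helper below is illustrative only, and the exact behavior of map_torch_type is assumed rather than taken from the library source.

```python
# Minimal sketch of the dtype-mapping idea behind this fix. The helper below
# is illustrative only; the real mapping lives in
# tilelang.utils.tensor.map_torch_type, whose exact behavior is assumed here.
import torch

def map_torch_type_sketch(dtype_str: str) -> torch.dtype:
    # Names that torch exposes directly ("float16", "bfloat16", ...) resolve
    # via getattr; anything else would need an explicit lookup table.
    return getattr(torch, dtype_str)

in_dtype = map_torch_type_sketch("float16")    # torch.float16
out_dtype = map_torch_type_sketch("float32")   # torch.float32

# Tensors are created with the mapped dtype once, so the kernel output and the
# reference share the same dtype and the stricter check_dtype=True default in
# torch_assert_close no longer reports a spurious mismatch.
A = torch.randn(128, 128, dtype=in_dtype)
B = torch.randn(128, 128, dtype=in_dtype)
```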
@@ -6,6 +6,7 @@ import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type
def matmul(
@@ -197,8 +198,11 @@ def run_gemm_jit_kernel(
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
A = torch.randn(M, K, dtype=in_dtype).cuda()
B = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
A = A.T
@@ -208,7 +212,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
C = C.to(out_dtype)
return C
ref_C = ref_program(A, B)
@@ -313,15 +317,16 @@ def run_ctypes_kernel_multi_stream(M,
)
matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
num_streams = 4
for _ in range(num_streams):
@@ -371,18 +376,22 @@ def run_ctypes_dynamic_shape(M,
N = 1024
if isinstance(K, T.Var):
K = 768
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
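The other half of the change shows up in every reference computation: the matmul is still carried out in float32 for accuracy, but the result is now cast back to the kernel's output dtype before the comparison, which would otherwise trip the stricter dtype check. A hedged sketch of that pattern, with shapes and dtypes chosen only for illustration:

```python
# Sketch of the corrected reference pattern used throughout these tests:
# accumulate in float32, then cast back to the kernel's output dtype so the
# comparison sees matching dtypes. Shapes and dtypes here are illustrative.
import torch

def ref_matmul(a: torch.Tensor, b: torch.Tensor, out_dtype: torch.dtype) -> torch.Tensor:
    c = torch.matmul(a.to(torch.float), b.to(torch.float))  # upcast for accuracy
    return c.to(out_dtype)                                   # match kernel output dtype

a = torch.randn(256, 128, dtype=torch.float16)
b = torch.randn(128, 256, dtype=torch.float16)
ref_c = ref_matmul(a, b, torch.float16)   # same dtype as the kernel's output tensor
```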
@@ -6,6 +6,7 @@ import tilelang.language as T
import tilelang.testing
import tilelang
import torch
from tilelang.utils.tensor import map_torch_type
def matmul(
@@ -197,8 +198,11 @@ def run_gemm_jit_kernel(
matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="cython")
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
A = torch.randn(M, K, dtype=in_dtype).cuda()
B = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
A = A.T
@@ -208,7 +212,7 @@ def run_gemm_jit_kernel(
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
C = C.to(out_dtype)
return C
ref_C = ref_program(A, B)
@@ -321,14 +325,17 @@ def run_cython_kernel_multi_stream(M,
matmul_kernel = tilelang.compile(program, execution_backend="cython")
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
num_streams = 4
for _ in range(num_streams):
@@ -378,18 +385,22 @@ def run_cython_dynamic_shape(M,
N = 1024
if isinstance(K, T.Var):
K = 768
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
if trans_B:
tensor_b = tensor_b.T
tensor_c = torch.randn(M, N, dtype=torch.__getattribute__(out_dtype)).cuda()
tensor_c = torch.randn(M, N, dtype=out_dtype).cuda()
matmul_kernel(tensor_a, tensor_b, tensor_c)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
@@ -443,8 +454,12 @@ def run_cython_dynamic_shape_with_out_idx(M,
N = 1024
if isinstance(K, T.Var):
K = 768
tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
in_dtype = map_torch_type(in_dtype)
out_dtype = map_torch_type(out_dtype)
tensor_a = torch.randn(M, K, dtype=in_dtype).cuda()
tensor_b = torch.randn(K, N, dtype=in_dtype).cuda()
if trans_A:
tensor_a = tensor_a.T
@@ -453,7 +468,7 @@ def run_cython_dynamic_shape_with_out_idx(M,
tensor_c = matmul_kernel(tensor_a, tensor_b)
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float))
tensor_ref_c = torch.matmul(tensor_a.to(torch.float), tensor_b.to(torch.float)).to(out_dtype)
tilelang.testing.torch_assert_close(
tensor_c, tensor_ref_c, atol=1e-2, rtol=1e-2, max_mismatched_ratio=0.05)
@@ -159,7 +159,7 @@ def evaluate_gemv_simt(
ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32))
if with_bias:
ref_c += Bias.to(torch.float32)
ref_c = ref_c.to(out_dtype)
print(C)
print(ref_c)
tilelang.testing.torch_assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
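The GEMV test gets the same treatment; note that the optional bias is also added in float32 before the final cast. A self-contained sketch, with sizes, dtypes, and the bias chosen purely for illustration rather than taken from the test configuration:

```python
# Illustrative GEMV reference after the fix; sizes, dtypes and the bias are
# placeholders, not the test's actual configuration.
import torch

out_dtype = torch.float16
A = torch.randn(1, 1024, dtype=torch.float16)
B = torch.randn(1024, 1024, dtype=torch.float16)
Bias = torch.randn(1024, dtype=torch.float16)

ref_c = torch.mm(A.to(torch.float32), B.T.to(torch.float32))
ref_c += Bias.to(torch.float32)   # bias accumulated in float32 as well
ref_c = ref_c.to(out_dtype)       # cast back so the dtype matches the kernel output
```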
@@ -199,7 +199,7 @@ def torch_assert_close(
verbose: bool = False,
equal_nan: bool = True,
check_device: bool = True,
check_dtype: bool = False,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
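With check_dtype now defaulting to True, existing call sites keep working as long as the two tensors already share a dtype, while tests that deliberately compare across dtypes have to opt out. A usage sketch, assuming torch_assert_close otherwise behaves as before and using only parameters visible in this diff:

```python
# Usage sketch for the new default; only parameters shown in this diff are
# used, and everything else about torch_assert_close is assumed unchanged.
import torch
import tilelang.testing

out = torch.randn(64, 64, dtype=torch.float16)
ref = out.clone()  # same values, same dtype

# check_dtype=True is now the default, so this also verifies matching dtypes.
tilelang.testing.torch_assert_close(out, ref, atol=1e-2, rtol=1e-2)

# Comparing against a float32 reference now requires opting out explicitly.
tilelang.testing.torch_assert_close(
    out, out.to(torch.float32), atol=1e-2, rtol=1e-2, check_dtype=False)
```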