Commit 7959d786 authored by Lei Wang, committed by GitHub

[CI] Comprehensive test cases for matmul dequantize (#32)

* installation script fix

* readme typo fix

* doc fix for dequantize gemm

* [Doc] remove CODE_OF_CONDUCT.md and SECURITY.md; update references in CONTRIBUTING.md

* [Doc] add unit tests for AnnotateDeviceRegions transform; remove SUPPORT.md

* update license

* [Enhancement] add tensor supply handling for unsigned integers; improve error message for execution backend assertion

* [Refactor] improve code readability by reformatting function signatures and assertions

* [Refactor] replace torch.manual_seed with tilelang.testing.set_random_seed for consistency in random seed handling
parent 34e0883d
*.h linguist-language=C++
@@ -8,7 +8,7 @@ from tvm import DataType
import tilelang as TL
import tilelang.language as T
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def matmul(
...
@@ -12,8 +12,6 @@ from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
-
def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
...
@@ -12,7 +12,7 @@ from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
@simplify_prim_func
...
@@ -11,7 +11,7 @@ import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter)
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
...
@@ -4,11 +4,211 @@ import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
-from tvm import DataType
+from tvm import DataType, tir
import tilelang as TL
import tilelang.language as T
from tilelang import JITKernel, Profiler

tilelang.testing.set_random_seed(0)
def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint8"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    # fp4 layout as implemented below: 1 sign bit, 3 exponent bits, no mantissa (s1e3m0)
    mask = tir.const((1 << nbit) - 1, "uint16")
    f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    val_f16 = tir.reinterpret(
        "float16",
        ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
    # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
    return val_f16
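

# For intuition, a hypothetical pure-Python mirror of the TIR decode above
# (illustrative only, not part of the committed tests; assumes the same
# s1e3m0 layout this file implements):
def _decode_f4_nibble_py(byte_val: int, pos: int) -> float:
    import struct
    f4 = (byte_val >> (pos * 4)) & 0xF           # extract one 4-bit field
    s = f4 >> 3                                  # sign bit
    e_f16 = (f4 & 7) | 8                         # 3 exponent bits, rebiased into fp16 range
    bits = ((e_f16 | (s << 5)) << 10) & 0xFFFF   # sign at bit 15, exponent at bits 10-14
    return struct.unpack('<e', bits.to_bytes(2, 'little'))[0]
# Example: _decode_f4_nibble_py(0b0101, 0) == 0.25 (fp16 exponent 13, bias 15 -> 2**-2).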
def torch_convert(tensor):

    def print_bit(name, val):
        val_cpu = val.cpu().item()
        binary_repr = f'{val_cpu:032b}'
        print(name, binary_repr)

    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
        s = f4 >> 3
        e_f4 = f4 & 7
        e_f16 = e_f4 | 8
        val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
        lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
        return lower_16_bits.view(torch.float16)

    N = tensor.shape[0]
    K = tensor.shape[1]
    # Each uint8 packs two fp4 values, so the dequantized tensor doubles K.
    new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor
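

# A vectorized alternative to the per-element loop above; an untested sketch
# under the same bit-layout assumptions (hypothetical helper, kept for reference):
def torch_convert_vectorized(tensor):
    low = (tensor & 0xF).to(torch.int32)    # low nibble of each byte
    high = (tensor >> 4).to(torch.int32)    # high nibble of each byte
    # Interleave nibbles along K: (N, K//2, 2) -> (N, K), matching j // 2 and j % 2.
    f4 = torch.stack((low, high), dim=-1).reshape(tensor.shape[0], -1)
    s = f4 >> 3
    e_f16 = (f4 & 7) | 8
    bits = ((e_f16 | (s << 5)) << 10) & 0xFFFF
    return bits.to(torch.uint16).view(torch.float16)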
def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"
    B_shape = (N, K // num_elems_per_byte)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    @T.prim_func
    def main(
            B: T.Buffer(B_shape, storage_dtype),
            C: T.Buffer((N, K), in_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                        num_bits,
                        B_local[i, j // num_elems_per_byte],
                        j % num_elems_per_byte,
                        dtype=in_dtype,
                    )
                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

    return main
def test_fp4_fp16_convert_close():
    N, K = 256, 256
    block_N, block_K = 64, 64
    program = _convert_test(
        N,
        K,
        block_N,
        block_K,
        "float16",
    )
    mod, params = tilelang.lower(program)
    mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)
    B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
    tl_out = mod.func(B)
    ref_out = torch_convert(B)
    assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
    print("Pass")
def matmul_fp16xfp4(M,
                    N,
                    K,
                    in_dtype,
                    out_dtype,
                    accum_dtype,
                    block_M=64,
                    block_N=64,
                    block_K=64,
                    num_stages=1,
                    threads=128):
    num_bits = 4
    torch.manual_seed(0)

    def kernel_func(block_M, block_N, block_K, num_stages, threads):
        num_elems_per_byte = 8 // num_bits
        storage_dtype = "uint8"
        A_shape = (M, K)
        B_shape = (N, K // num_elems_per_byte)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K // num_elems_per_byte)
        B_dequantize_shared_shape = (block_N, block_K)
        assert K % block_K == 0

        @T.prim_func
        def main(
                A: T.Buffer(A_shape, in_dtype),
                B: T.Buffer(B_shape, storage_dtype),
                Ct: T.Buffer((N, M), out_dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })

                T.clear(Ct_local)
                for k in T.Pipelined(K // block_K, num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    # Accumulate Ct += B_dequantized @ A^T, so the result lands transposed.
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, Ct_shared)
                T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
                                     by * block_M:(by + 1) * block_M])

        return main

    # Forward the caller's tiling parameters instead of hard-coding them.
    return kernel_func(
        block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages, threads=threads)
def ref_program(A, qB):
    dtypeC = "float16"
    B = torch_convert(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(torch.__getattribute__(dtypeC))
    # The kernel emits Ct of shape (N, M), so return the transposed product.
    return C.transpose(0, 1)
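

# Shape sanity check for ref_program (illustrative sizes, not part of the
# committed tests): qB packs two fp4 values per byte, so qB is (N, K // 2)
# and the returned tensor is the transposed product of shape (N, M).
#   M, N, K = 128, 64, 256
#   A = torch.randn(M, K, dtype=torch.float16)
#   qB = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8)
#   assert ref_program(A, qB).shape == (N, M)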
def assert_simple_impl_float16xfp4_gemm(M,
                                        N,
                                        K,
                                        in_dtype,
                                        out_dtype,
                                        accum_dtype,
                                        block_M=64,
                                        block_N=64,
                                        block_K=64,
                                        num_stages=1,
                                        threads=128):
    func = matmul_fp16xfp4(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K,
                           num_stages, threads)
    torch_func = JITKernel(func, [2])
    profiler = torch_func.get_profiler()
    profiler.assert_allclose(ref_program)


def test_simple_impl_float16xfp4_gemm():
    assert_simple_impl_float16xfp4_gemm(256, 256, 256, "float16", "float16", "float32", 64, 64, 64,
                                        1, 128)
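

# Assuming tilelang.testing re-exports tvm.testing's main() via
# `from tvm.testing.utils import *` (see the tilelang/testing hunk below),
# the file can also be run directly; a hedged convenience entry point:
if __name__ == "__main__":
    tilelang.testing.main()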
def matmul(
...
@@ -13,7 +13,7 @@ from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
...
@@ -11,7 +11,7 @@ import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
...
@@ -17,7 +17,7 @@ from tilelang.intrinsics.mma_macro_generator import (
)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
@simplify_prim_func
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
import tilelang.testing
...
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pytest
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
...
@@ -69,7 +69,8 @@ class JITKernel(object):
        target = Target(target)
        # Validate the execution backend.
-       assert execution_backend in ["dl_pack", "torch_cpp", "ctypes"], "Invalid execution backend."
+       assert execution_backend in ["dl_pack", "torch_cpp",
+                                    "ctypes"], f"Invalid execution backend. {execution_backend}"
        # Compile the TileLang function and create a kernel adapter for execution.
        adapter = self._compile_and_create_adapter(func)
...
@@ -3,6 +3,9 @@
import sys
import inspect
import pytest
+import random
+import torch
+import numpy as np
from tvm.testing.utils import *
@@ -75,3 +78,11 @@ def torch_assert_close(tensor_a,
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
    else:
        return True
+
+
+def set_random_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
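
This is the helper the seed hunks above now call: it seeds Python's `random`, NumPy, and Torch (including all CUDA devices when available) in one place. A typical test module seeds once at import time, as the diffs above do:

    import tilelang.testing

    tilelang.testing.set_random_seed(0)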
@@ -20,8 +20,7 @@ def get_tensor_supply(supply_type: TensorSupplyType):
    def get_tensor(tensor: TensorType) -> torch.Tensor:
        dtype = torch.__getattribute__(str(tensor.dtype))
        device = torch.cuda.current_device()
-       # torch.manual_seed(0)
-       # torch.cuda.manual_seed(0)
        shape = list(map(int, tensor.shape))
        if dtype == torch.int8 and supply_type in [
            TensorSupplyType.Uniform,
@@ -30,7 +29,11 @@ def get_tensor_supply(supply_type: TensorSupplyType):
            return torch.ones(*shape, device=device, dtype=dtype)

        if supply_type == TensorSupplyType.Integer:
-           return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
+           is_unsigned = tensor.dtype.startswith("uint")
+           if is_unsigned:
+               return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
+           else:
+               return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.Uniform:
            return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0)
        elif supply_type == TensorSupplyType.Normal:
...