Commit 7959d786 authored by Lei Wang, committed by GitHub

[CI] Comprehensive test-case implementation for Matmul Dequantize (#32)

* installation script fix

* readme typo fix

* doc fix for dequantize gemm

* [Doc] remove CODE_OF_CONDUCT.md and SECURITY.md; update references in CONTRIBUTING.md

* [Doc] add unit tests for AnnotateDeviceRegions transform; remove SUPPORT.md

* update license

* [Enhancement] add tensor supply handling for unsigned integers; improve error message for execution backend assertion

* [Refactor] improve code readability by reformatting function signatures and assertions

* [Refactor] replace torch.manual_seed with tilelang.testing.set_random_seed for consistency in random seed handling
parent 34e0883d
*.h linguist-language=C++
@@ -8,7 +8,7 @@ from tvm import DataType
import tilelang as TL
import tilelang.language as T
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def matmul(
...
@@ -12,8 +12,6 @@ from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
...
@@ -12,7 +12,7 @@ from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
@simplify_prim_func
...
@@ -11,7 +11,7 @@ import tilelang.language as T
from tilelang.intrinsics.utils import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (TensorCoreIntrinEmitter)
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
...
@@ -4,11 +4,211 @@ import torch
import torch.backends
import tilelang.testing
from tilelang import tvm as tvm
-from tvm import DataType
+from tvm import DataType, tir
import tilelang as TL
import tilelang.language as T
from tilelang import JITKernel, Profiler

-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)

def _tir_u8_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
    assert nbit == 4
    assert dtype == "float16"
    assert val.dtype == "uint8"
    # e_f4 == 0 -> e_f16 = 0
    # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2
    # s1e2n1
    mask = tir.const((1 << nbit) - 1, "uint16")
    f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
    s = f4 >> tir.const(3, "uint16")
    e_f4 = f4 & tir.const(7, "uint16")
    e_f16 = e_f4 | tir.const(8, "uint16")
    val_f16 = tir.reinterpret(
        "float16",
        ((e_f16 | (s << tir.const(5, "uint16"))) << tir.const(10, "uint16")).astype("uint16"))
    # return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16)
    return val_f16
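

# --- Illustrative aside, not part of this commit: the same nibble decode in
# plain Python, to make the bit layout concrete. `_decode_fp4_nibble` is a
# hypothetical helper; nibble 0x1 (s=0, e_f4=1) becomes fp16 exponent 9 with a
# zero mantissa, i.e. 2**(9 - 15) = 0.015625.
import struct

def _decode_fp4_nibble(byte: int, pos: int) -> float:
    f4 = (byte >> (pos * 4)) & 0xF           # select the 4-bit code
    s, e_f4 = f4 >> 3, f4 & 7                # sign bit, 3-bit exponent field
    bits = ((e_f4 | 8) | (s << 5)) << 10     # assemble fp16 sign/exponent, mantissa = 0
    return struct.unpack('<e', struct.pack('<H', bits))[0]

assert _decode_fp4_nibble(0x31, 0) == 0.015625  # low nibble 0x1 -> 2**-6
assert _decode_fp4_nibble(0x31, 1) == 0.0625    # high nibble 0x3 -> 2**-4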

def torch_convert(tensor):

    def print_bit(name, val):
        val_cpu = val.cpu().item()
        binary_repr = f'{val_cpu:032b}'
        print(name, binary_repr)

    def _convert(val, pos):
        assert val.dtype == torch.uint8
        val = val.view(torch.int8)
        mask = (1 << 4) - 1
        f4 = ((val >> (pos * 4)) & mask).to(torch.int16)
        s = f4 >> 3
        e_f4 = f4 & 7
        e_f16 = e_f4 | 8
        val_f16 = ((e_f16 | (s << 5)) << 10) & 0xFFFF
        lower_16_bits = (val_f16 & 0xFFFF).to(torch.uint16)
        return lower_16_bits.view(torch.float16)

    N = tensor.shape[0]
    K = tensor.shape[1]
    new_tensor = torch.empty(N, K * 2, dtype=torch.float16, device=tensor.device)
    for i in range(new_tensor.shape[0]):
        for j in range(new_tensor.shape[1]):
            new_tensor[i][j] = _convert(tensor[i][j // 2], j % 2)
    return new_tensor
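

# --- Illustrative aside, not part of this commit: the packing convention that
# torch_convert assumes. Element j lives in byte j // 2, with even indices in
# the low nibble, so one uint8 expands to two fp16 values:
_lo, _hi = 0x1, 0x3
_byte = (_hi << 4) | _lo                          # 0x31
assert (_byte >> 0) & 0xF == _lo and (_byte >> 4) & 0xF == _hi
# e.g. torch_convert(torch.tensor([[0x31]], dtype=torch.uint8, device="cuda"))
# would yield [[0.015625, 0.0625]] as float16.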

def _convert_test(N, K, block_N, block_K, in_dtype, num_bits=4, threads=128):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"
    B_shape = (N, K // num_elems_per_byte)
    B_shared_shape = (block_N, block_K // num_elems_per_byte)
    B_dequantize_shared_shape = (block_N, block_K)

    @T.prim_func
    def main(
            B: T.Buffer(B_shape, storage_dtype),
            C: T.Buffer((N, K), in_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx):
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=1):
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                T.copy(B_shared, B_local)
                for i, j in T.Parallel(block_N, block_K):
                    B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                        num_bits,
                        B_local[i, j // num_elems_per_byte],
                        j % num_elems_per_byte,
                        dtype=in_dtype,
                    )
                T.copy(B_dequantize_local, C[bx * block_N, k * block_K])

    return main

def test_fp4_fp16_convert_close():
    N, K = 256, 256
    block_N, block_K = 64, 64
    program = _convert_test(
        N,
        K,
        block_N,
        block_K,
        "float16",
    )
    mod, params = tilelang.lower(program)
    mod = Profiler(mod, params, [1], tilelang.TensorSupplyType.Integer)
    B = torch.randint(0, 16, (N, K // 2), dtype=torch.uint8, device="cuda").to(torch.uint8)
    tl_out = mod.func(B)
    ref_out = torch_convert(B)
    assert torch.allclose(tl_out, ref_out, rtol=0.01, atol=0.01), (tl_out, ref_out)
    print("Pass")

def matmul_fp16xfp4(M,
                    N,
                    K,
                    in_dtype,
                    out_dtype,
                    accum_dtype,
                    block_M=64,
                    block_N=64,
                    block_K=64,
                    num_stages=1,
                    threads=128):
    num_bits = 4

    def kernel_func(block_M, block_N, block_K, num_stages, threads):
        num_elems_per_byte = 8 // num_bits
        storage_dtype = "uint8"
        A_shape = (M, K)
        B_shape = (N, K // num_elems_per_byte)
        A_shared_shape = (block_M, block_K)
        B_shared_shape = (block_N, block_K // num_elems_per_byte)
        B_dequantize_shared_shape = (block_N, block_K)
        assert K % block_K == 0

        @T.prim_func
        def main(
                A: T.Buffer(A_shape, in_dtype),
                B: T.Buffer(B_shape, storage_dtype),
                Ct: T.Buffer((N, M), out_dtype),
        ):
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
                B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                B_dequantize_prev_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
                Ct_local = T.alloc_fragment((block_N, block_M), accum_dtype)
                Ct_shared = T.alloc_shared((block_N, block_M), out_dtype)

                T.annotate_layout({
                    B_shared: tilelang.layout.make_swizzled_layout(B_shared),
                    Ct_shared: tilelang.layout.make_swizzled_layout(Ct_shared),
                })

                T.clear(Ct_local)
                for k in T.Pipelined(K // block_K, num_stages=num_stages):
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                    T.copy(B_shared, B_local)
                    for i, j in T.Parallel(block_N, block_K):
                        B_dequantize_local[i, j] = _tir_u8_to_f4_to_f16(
                            num_bits,
                            B_local[i, j // num_elems_per_byte],
                            j % num_elems_per_byte,
                            dtype=in_dtype,
                        )
                    T.copy(B_dequantize_local, B_dequantize_prev_local)
                    T.gemm(B_dequantize_prev_local, A_shared, Ct_local, transpose_B=True)
                T.copy(Ct_local, Ct_shared)
                T.copy(Ct_shared, Ct[bx * block_N:(bx + 1) * block_N,
                                     by * block_M:(by + 1) * block_M])

        return main

    # Pass the caller's tiling parameters through instead of hard-coded values.
    return kernel_func(block_M, block_N, block_K, num_stages, threads)

def ref_program(A, qB):
    dtypeC = "float16"
    B = torch_convert(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(torch.__getattribute__(dtypeC))
    return C.transpose(0, 1)
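

# --- Illustrative aside, not part of this commit: why ref_program transposes
# its result. The kernel writes Ct with shape (N, M), i.e. Ct = dequant(B) @ A^T,
# and (A @ B^T)^T == B @ A^T, checked here on small random matrices:
_A, _B = torch.randn(4, 8), torch.randn(6, 8)
assert torch.allclose(torch.matmul(_A, _B.T).T, torch.matmul(_B, _A.T), atol=1e-5)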

def assert_simple_impl_float16xfp4_gemm(M,
                                        N,
                                        K,
                                        in_dtype,
                                        out_dtype,
                                        accum_dtype,
                                        block_M=64,
                                        block_N=64,
                                        block_K=64,
                                        num_stages=1,
                                        threads=128):
    func = matmul_fp16xfp4(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K,
                           num_stages, threads)
    torch_func = JITKernel(func, [2])
    profiler = torch_func.get_profiler()
    profiler.assert_allclose(ref_program)


def test_simple_impl_float16xfp4_gemm():
    assert_simple_impl_float16xfp4_gemm(256, 256, 256, "float16", "float16", "float32", 64, 64, 64,
                                        1, 128)

def matmul(
...
@@ -13,7 +13,7 @@ from tilelang.intrinsics.mma_macro_generator import (
    TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
...
@@ -11,7 +11,7 @@ import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
def make_swizzle_layout(shared_buf):
...
@@ -17,7 +17,7 @@ from tilelang.intrinsics.mma_macro_generator import (
)
from tilelang.transform import simplify_prim_func
-torch.manual_seed(0)
+tilelang.testing.set_random_seed(0)
@simplify_prim_func
...
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
import tilelang
import tilelang.testing
...
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
import pytest
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
...
@@ -69,7 +69,8 @@ class JITKernel(object):
            target = Target(target)

        # Validate the execution backend.
-        assert execution_backend in ["dl_pack", "torch_cpp", "ctypes"], "Invalid execution backend."
+        assert execution_backend in ["dl_pack", "torch_cpp",
+                                     "ctypes"], f"Invalid execution backend: {execution_backend}"

        # Compile the TileLang function and create a kernel adapter for execution.
        adapter = self._compile_and_create_adapter(func)
...
@@ -3,6 +3,9 @@
import sys
import inspect
import pytest
+import random
+import torch
+import numpy as np
from tvm.testing.utils import *
@@ -75,3 +78,11 @@ def torch_assert_close(tensor_a,
            f"Greatest relative difference: {(diff / (torch.abs(tensor_b) + 1e-12)).max().item()}.")
    else:
        return True


def set_random_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
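

# --- Illustrative aside, not part of this commit: seeding Python, NumPy and
# torch together makes tensor supplies reproducible across runs.
set_random_seed(0)
_a = torch.randn(4)
set_random_seed(0)
assert torch.equal(_a, torch.randn(4))  # identical draws after reseeding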
@@ -20,8 +20,7 @@ def get_tensor_supply(supply_type: TensorSupplyType):
    def get_tensor(tensor: TensorType) -> torch.Tensor:
        dtype = torch.__getattribute__(str(tensor.dtype))
        device = torch.cuda.current_device()
-        # torch.manual_seed(0)
-        # torch.cuda.manual_seed(0)
        shape = list(map(int, tensor.shape))
        if dtype == torch.int8 and supply_type in [
                TensorSupplyType.Uniform,
@@ -30,6 +29,10 @@ def get_tensor_supply(supply_type: TensorSupplyType):
            return torch.ones(*shape, device=device, dtype=dtype)

        if supply_type == TensorSupplyType.Integer:
            is_unsigned = tensor.dtype.startswith("uint")
            if is_unsigned:
                return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
            else:
                return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
        elif supply_type == TensorSupplyType.Uniform:
            return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0)
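
# --- Illustrative aside, not part of this commit: the unsigned branch exists
# because randint with a negative low cannot fill an unsigned dtype:
# torch.randint(low=-2, high=3, size=(4,), dtype=torch.uint8)  # raises RuntimeError
# torch.randint(low=0, high=3, size=(4,), dtype=torch.uint8)   # ok: values in {0, 1, 2}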
...