Commit a1da26f2 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Disable force inline for ldmatrix (#227)

* Refactor GEMM and Bulk Copy operations to enhance layout handling and support for Hopper architecture

- Update `ComputeWarpPartition` to include a new parameter for Hopper WGMMA support.
- Modify layout checks in `LowerBulkCopy` to accommodate new GEMM layout types.
- Enhance layout inference logic in `InferLayout` for better compatibility with Hopper architecture.
- Include necessary header files for built-in operations and layout inference improvements.

* Refactor parameter formatting in CUDA matrix load functions for consistency

- Adjusted parameter alignment in `ptx_ldmatrix_x1`, `ptx_ldmatrix_x2`, `ptx_ldmatrix_x4`, and their transposed counterparts for improved readability.
- Added a blank line in `get_tensor_supply` function in `tensor.py` to enhance code clarity.

* Enhance tensor supply generation in `get_tensor_supply` function

- Introduced handling for unsigned integer and float8 tensor types, allowing for specific random tensor generation based on data type.
- Updated logic to return appropriate random tensors for different data types, improving flexibility and functionality of tensor supply generation.
- Refactored existing conditions for clarity and maintainability.
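As a rough illustration of the behavior described above (a standalone sketch, not the library code itself), random inputs for unsigned-integer and float8 buffers can be produced along these lines. The float8 dtype name is an assumption and requires a reasonably recent PyTorch build:

```python
import torch

shape = (4, 4)

# Unsigned integer buffers: sample small non-negative values directly.
u = torch.randint(low=0, high=3, size=shape, dtype=torch.uint8)

# float8 buffers: sample int8 values first, then cast, since random
# generation directly into a float8 dtype is not generally supported.
if hasattr(torch, "float8_e4m3fn"):  # assumption: this torch build ships float8
    f8 = torch.randint(low=-128, high=128, size=shape,
                       dtype=torch.int8).to(torch.float8_e4m3fn)
```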

* Fix tensor supply generation logic in `get_tensor_supply` function

- Updated the variable reference from `tensor` to `param` to ensure correct handling of tensor data types.
- Improved the accuracy of unsigned integer and float8 checks for tensor supply generation, enhancing functionality and reliability.

* Enhance tensor supply checks in `get_tensor_supply` function

- Updated the logic for identifying unsigned integers and float8 types by using `removeprefix` on the dtype string, improving accuracy in tensor supply generation.
- Ensured better handling of tensor data types for more reliable random tensor generation based on the updated checks.
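A small sketch of the dtype-string check that `removeprefix` enables (assuming Python 3.9+ for `str.removeprefix`); without stripping the `torch.` prefix, `startswith("uint")` would never match:

```python
import torch

def short_dtype(dtype: torch.dtype) -> str:
    # str(torch.uint8) == "torch.uint8"; strip the module prefix first.
    return str(dtype).removeprefix("torch.")

print(short_dtype(torch.uint8).startswith("uint"))      # True
print(str(torch.uint8).startswith("uint"))               # False without removeprefix
print(short_dtype(torch.int8).startswith("uint"))        # False
print(short_dtype(torch.float16).startswith("float8"))   # False ("float16" is not "float8...")
if hasattr(torch, "float8_e4m3fn"):                      # float8 needs a newer torch build
    print(short_dtype(torch.float8_e4m3fn).startswith("float8"))  # True
```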

* Enhance KernelParam functionality and improve tensor supply checks

- Added methods `is_unsigned` and `is_float8` to the `KernelParam` class for better type identification of parameters.
- Updated the `get_tensor_supply` function to utilize the new methods, improving clarity and accuracy in tensor supply generation based on parameter types.
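To show how the new `KernelParam` methods are intended to be consumed by `get_tensor_supply`, here is a condensed, hypothetical sketch; `ParamStub` and `describe_supply` are stand-ins for illustration only and are not part of the library:

```python
from dataclasses import dataclass
from typing import Tuple
import torch

@dataclass
class ParamStub:
    """Hypothetical stand-in mirroring the relevant parts of KernelParam."""
    dtype: torch.dtype
    shape: Tuple[int, ...]

    def is_unsigned(self) -> bool:
        return str(self.dtype).removeprefix("torch.").startswith("uint")

    def is_float8(self) -> bool:
        return str(self.dtype).removeprefix("torch.").startswith("float8")

def describe_supply(param: ParamStub) -> str:
    # Mirrors the branch order used by the updated get_tensor_supply.
    if param.is_unsigned():
        return "small non-negative integers"
    if param.is_float8():
        return "int8 samples cast to the float8 dtype"
    return "default integer / normal sampling"

print(describe_supply(ParamStub(torch.uint8, (4, 4))))  # small non-negative integers
```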
parent 3824adab
@@ -24,6 +24,7 @@ using int4_t = int4;
#define ushort unsigned short
#define TL_DEVICE __forceinline__ __device__
+#define TL_DEVICE_NOINLINE __noinline__ __device__
// Pack two half values.
TL_DEVICE unsigned __pack_half2(const half x, const half y) {
......
@@ -6,7 +6,7 @@
namespace tl {
-TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
@@ -15,7 +15,7 @@ TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
-TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
@@ -24,7 +24,7 @@ TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
-TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
@@ -34,7 +34,7 @@ TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
-TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
@@ -43,7 +43,7 @@ TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
-TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
@@ -53,7 +53,7 @@ TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
: "r"(smem_int_ptr));
}
-TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
void *const local_ptr) {
uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
......
@@ -147,9 +147,12 @@ public:
continue;
// Check if buffer exists in use_list_
-ICHECK(use_list_.count(buffer))
-<< "Buffer " << buffer << " not found in use_list_. "
-<< "Potential mismatch between inference updates and use_list_.";
+if (!use_list_.count(buffer)) {
+LOG(WARNING) << "Buffer " << buffer << " not found in use_list_. "
+<< "Potential mismatch between inference updates and "
+<< "use_list_.";
+continue;
+}
// Push back into BFS queue
for (int idx : use_list_[buffer]) {
......
@@ -64,3 +64,21 @@ class KernelParam:
bool: True if parameter has no dimensions (empty shape), False otherwise
"""
return len(self.shape) == 0
+def is_unsigned(self) -> bool:
+"""
+Checks if the parameter represents an unsigned integer type.
+Returns:
+bool: True if parameter is an unsigned integer type, False otherwise
+"""
+return str(self.dtype).removeprefix("torch.").startswith("uint")
+def is_float8(self) -> bool:
+"""
+Checks if the parameter represents a float8 type.
+Returns:
+bool: True if parameter is a float8 type, False otherwise
+"""
+return str(self.dtype).removeprefix("torch.").startswith("float8")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import tir, IRModule
from tvm.target import Target
import tilelang
@@ -39,7 +36,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
mod = tilelang.transform.RewriteWgmmaSync()(mod)
-# mod = tilelang.transform.WarpSpecializedPipeline()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
......
@@ -192,14 +192,14 @@ class JITKernel(object):
return cls(func=tilelang_func, **kwargs)
def get_profiler(self,
-tensor_supply_type: TensorSupplyType = TensorSupplyType.Integer) -> Profiler:
+tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
"""
Creates a profiler to benchmark the compiled runtime module.
Parameters
----------
tensor_supply_type : TensorSupplyType, optional
-The type of input tensors to supply for profiling (default: TensorSupplyType.Integer).
+The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).
Returns
-------
......
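The practical effect of the default change above: a profiler created without arguments now uses `TensorSupplyType.Auto`, which picks an input distribution per parameter dtype. A hedged usage sketch follows; it assumes an already compiled `JITKernel` instance named `kernel`, and the import path for `TensorSupplyType` is an assumption that may differ:

```python
from tilelang.utils.tensor import TensorSupplyType  # assumed import path

# Assumes `kernel` is an already-compiled tilelang JITKernel instance.
profiler = kernel.get_profiler()  # now defaults to TensorSupplyType.Auto
# Opt back into the previous behavior explicitly if integer inputs are required:
profiler_int = kernel.get_profiler(tensor_supply_type=TensorSupplyType.Integer)
```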
@@ -58,12 +58,17 @@ def get_tensor_supply(supply_type: TensorSupplyType):
shape = list(map(int, param.shape))
if supply_type == TensorSupplyType.Auto:
-if dtype == torch.float16 or dtype == torch.float32:
+is_unsigned = param.is_unsigned()
+is_float8 = param.is_float8()
+if is_unsigned:
+return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
+elif is_float8:
+return torch.randint(
+low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
+elif dtype in {torch.float16, torch.float32, torch.bfloat16}:
return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
-elif dtype == torch.uint8:
-return torch.randint(0, 2, size=shape, device=device, dtype=dtype)
else:
-raise NotImplementedError(dtype)
+return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
if dtype == torch.int8 and supply_type in [
TensorSupplyType.Uniform,
@@ -72,8 +77,8 @@ def get_tensor_supply(supply_type: TensorSupplyType):
return torch.ones(*shape, device=device, dtype=dtype)
if supply_type == TensorSupplyType.Integer:
-is_unsigned = str(dtype).removeprefix("torch.").startswith("uint")
-is_float8 = str(dtype).removeprefix("torch.").startswith("float8")
+is_unsigned = param.is_unsigned()
+is_float8 = param.is_float8()
if is_unsigned:
return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
elif is_float8:
@@ -97,6 +102,7 @@ def get_tensor_supply(supply_type: TensorSupplyType):
return get_tensor
# TODO: Align with torch.testing.assert_close
def torch_assert_close(
tensor_a,
tensor_b,
......