Commit a1da26f2 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Disable force inline for ldmatrix (#227)

* Refactor GEMM and Bulk Copy operations to enhance layout handling and support for Hopper architecture

- Update `ComputeWarpPartition` to include a new parameter for Hopper WGMMA support.
- Modify layout checks in `LowerBulkCopy` to accommodate new GEMM layout types.
- Enhance layout inference logic in `InferLayout` for better compatibility with Hopper architecture.
- Include necessary header files for built-in operations and layout inference improvements.

* Refactor parameter formatting in CUDA matrix load functions for consistency

- Adjusted parameter alignment in `ptx_ldmatrix_x1`, `ptx_ldmatrix_x2`, `ptx_ldmatrix_x4`, and their transposed counterparts for improved readability.
- Added a blank line in `get_tensor_supply` function in `tensor.py` to enhance code clarity.

* Enhance tensor supply generation in `get_tensor_supply` function

- Introduced handling for unsigned integer and float8 tensor types, so random tensors are generated according to the parameter's data type.
- Updated the logic to return a suitable random tensor for each data type (see the sketch after this list).
- Refactored the existing conditions for clarity and maintainability.
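
As a rough illustration of the dispatch these bullets describe, here is a standalone sketch (not the actual `get_tensor_supply` code; the helper name `make_auto_tensor` and the `device` default are invented for the example):

```python
import torch

def make_auto_tensor(shape, dtype, device="cpu"):
    # Hypothetical sketch of the Auto supply rules described above.
    name = str(dtype).removeprefix("torch.")
    if name.startswith("uint"):
        # Unsigned integers: small non-negative values.
        return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
    if name.startswith("float8"):
        # float8: draw random int8 values, then cast to the float8 dtype.
        return torch.randint(
            low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
    if dtype in {torch.float16, torch.float32, torch.bfloat16}:
        # Ordinary floats: sample from a normal distribution.
        return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
    # Everything else (e.g. signed integers): small values around zero.
    return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
```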

* Fix tensor supply generation logic in `get_tensor_supply` function

- Updated the variable reference from `tensor` to `param` to ensure correct handling of tensor data types.
- Improved the accuracy of the unsigned integer and float8 checks used for tensor supply generation.

* Enhance tensor supply checks in `get_tensor_supply` function

- Updated the unsigned integer and float8 checks to call `removeprefix("torch.")` on the dtype string before matching, since `str()` of a torch dtype carries a `torch.` prefix (see the example below).
- Ensured more reliable random tensor generation based on the corrected dtype checks.
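
Why `removeprefix` matters here, as a quick illustration (not project code):

```python
import torch

dtype = torch.uint8
print(str(dtype))                                            # "torch.uint8"
print(str(dtype).startswith("uint"))                         # False
print(str(dtype).removeprefix("torch.").startswith("uint"))  # True
```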

* Enhance KernelParam functionality and improve tensor supply checks

- Added `is_unsigned` and `is_float8` methods to the `KernelParam` class for cleaner type identification of parameters.
- Updated the `get_tensor_supply` function to use the new methods instead of repeating the dtype-string checks inline (a usage sketch follows this list).
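
A hypothetical usage sketch under these changes; the `kernel` and `param` objects and the `TensorSupplyType` import path are assumptions for illustration, not part of the diff:

```python
from tilelang.utils.tensor import TensorSupplyType  # import path assumed

# `kernel` is assumed to be a compiled JITKernel, `param` one of its KernelParams.
profiler = kernel.get_profiler()  # now defaults to TensorSupplyType.Auto
profiler_int = kernel.get_profiler(tensor_supply_type=TensorSupplyType.Integer)  # previous default

if param.is_unsigned():
    print("unsigned integer parameter")
elif param.is_float8():
    print("float8 parameter")
```
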
parent 3824adab
@@ -24,6 +24,7 @@ using int4_t = int4;
 #define ushort unsigned short
 #define TL_DEVICE __forceinline__ __device__
+#define TL_DEVICE_NOINLINE __noinline__ __device__
 
 // Pack two half values.
 TL_DEVICE unsigned __pack_half2(const half x, const half y) {
...
@@ -6,8 +6,8 @@
 namespace tl {
 
-TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
-                               void *const local_ptr) {
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x1(void const *const smem_ptr,
+                                        void *const local_ptr) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
   asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n"
@@ -15,8 +15,8 @@ TL_DEVICE void ptx_ldmatrix_x1(void const *const smem_ptr,
               : "r"(smem_int_ptr));
 }
 
-TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
-                               void *const local_ptr) {
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x2(void const *const smem_ptr,
+                                        void *const local_ptr) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
   asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
@@ -24,8 +24,8 @@ TL_DEVICE void ptx_ldmatrix_x2(void const *const smem_ptr,
               : "r"(smem_int_ptr));
 }
 
-TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
-                               void *const local_ptr) {
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x4(void const *const smem_ptr,
+                                        void *const local_ptr) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
   asm volatile(
@@ -34,8 +34,8 @@ TL_DEVICE void ptx_ldmatrix_x4(void const *const smem_ptr,
               : "r"(smem_int_ptr));
 }
 
-TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
-                                     void *const local_ptr) {
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
+                                              void *const local_ptr) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
   asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n"
@@ -43,8 +43,8 @@ TL_DEVICE void ptx_ldmatrix_x1_trans(void const *const smem_ptr,
               : "r"(smem_int_ptr));
 }
 
-TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
-                                     void *const local_ptr) {
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
+                                              void *const local_ptr) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
   asm volatile(
@@ -53,8 +53,8 @@ TL_DEVICE void ptx_ldmatrix_x2_trans(void const *const smem_ptr,
               : "r"(smem_int_ptr));
 }
 
-TL_DEVICE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
-                                     void *const local_ptr) {
+TL_DEVICE_NOINLINE void ptx_ldmatrix_x4_trans(void const *const smem_ptr,
+                                              void *const local_ptr) {
   uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
   int32_t *value = reinterpret_cast<int32_t *>(local_ptr);
   asm volatile(
...
@@ -147,9 +147,12 @@ public:
       continue;
 
     // Check if buffer exists in use_list_
-    ICHECK(use_list_.count(buffer))
-        << "Buffer " << buffer << " not found in use_list_. "
-        << "Potential mismatch between inference updates and use_list_.";
+    if (!use_list_.count(buffer)) {
+      LOG(WARNING) << "Buffer " << buffer << " not found in use_list_. "
+                   << "Potential mismatch between inference updates and "
+                   << "use_list_.";
+      continue;
+    }
 
     // Push back into BFS queue
     for (int idx : use_list_[buffer]) {
...
@@ -64,3 +64,21 @@ class KernelParam:
             bool: True if parameter has no dimensions (empty shape), False otherwise
         """
         return len(self.shape) == 0
+
+    def is_unsigned(self) -> bool:
+        """
+        Checks if the parameter represents an unsigned integer type.
+
+        Returns:
+            bool: True if parameter is an unsigned integer type, False otherwise
+        """
+        return str(self.dtype).removeprefix("torch.").startswith("uint")
+
+    def is_float8(self) -> bool:
+        """
+        Checks if the parameter represents a float8 type.
+
+        Returns:
+            bool: True if parameter is a float8 type, False otherwise
+        """
+        return str(self.dtype).removeprefix("torch.").startswith("float8")
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
 from tvm import tir, IRModule
 from tvm.target import Target
 import tilelang
@@ -39,7 +36,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
         mod = tir.transform.LowerOpaqueBlock()(mod)
         mod = tilelang.transform.MergeIfStmt()(mod)
         mod = tilelang.transform.RewriteWgmmaSync()(mod)
-        # mod = tilelang.transform.WarpSpecializedPipeline()(mod)
         mod = tilelang.transform.InjectFenceProxy()(mod)
     else:
         mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
...
@@ -192,14 +192,14 @@ class JITKernel(object):
         return cls(func=tilelang_func, **kwargs)
 
     def get_profiler(self,
-                     tensor_supply_type: TensorSupplyType = TensorSupplyType.Integer) -> Profiler:
+                     tensor_supply_type: TensorSupplyType = TensorSupplyType.Auto) -> Profiler:
         """
         Creates a profiler to benchmark the compiled runtime module.
 
         Parameters
         ----------
         tensor_supply_type : TensorSupplyType, optional
-            The type of input tensors to supply for profiling (default: TensorSupplyType.Integer).
+            The type of input tensors to supply for profiling (default: TensorSupplyType.Auto).
 
         Returns
         -------
...
@@ -58,12 +58,17 @@ def get_tensor_supply(supply_type: TensorSupplyType):
         shape = list(map(int, param.shape))
 
         if supply_type == TensorSupplyType.Auto:
-            if dtype == torch.float16 or dtype == torch.float32:
+            is_unsigned = param.is_unsigned()
+            is_float8 = param.is_float8()
+            if is_unsigned:
+                return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
+            elif is_float8:
+                return torch.randint(
+                    low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
+            elif dtype in {torch.float16, torch.float32, torch.bfloat16}:
                 return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0)
-            elif dtype == torch.uint8:
-                return torch.randint(0, 2, size=shape, device=device, dtype=dtype)
             else:
-                raise NotImplementedError(dtype)
+                return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
 
         if dtype == torch.int8 and supply_type in [
             TensorSupplyType.Uniform,
@@ -72,8 +77,8 @@ def get_tensor_supply(supply_type: TensorSupplyType):
             return torch.ones(*shape, device=device, dtype=dtype)
 
         if supply_type == TensorSupplyType.Integer:
-            is_unsigned = str(dtype).removeprefix("torch.").startswith("uint")
-            is_float8 = str(dtype).removeprefix("torch.").startswith("float8")
+            is_unsigned = param.is_unsigned()
+            is_float8 = param.is_float8()
             if is_unsigned:
                 return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
             elif is_float8:
@@ -97,6 +102,7 @@ def get_tensor_supply(supply_type: TensorSupplyType):
     return get_tensor
 
 
+# TODO: Align with torch.testing.assert_close
 def torch_assert_close(
     tensor_a,
     tensor_b,
...