Unverified Commit 5cb5c068 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Bugfix] Fix missing host cuTensorMapEncodeIm2col call (#1094)

parent bddb125e
......@@ -122,6 +122,7 @@ def main(argv=None):
out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")
if __name__ == "__main__":
......
......@@ -163,7 +163,7 @@ private:
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
auto arg0 = op->args[0].as<Call>();
bool is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
......@@ -203,7 +203,7 @@ private:
void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
......@@ -451,7 +451,7 @@ private:
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
......@@ -459,7 +459,8 @@ private:
auto new_args = op->args;
auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
if (is_1d_tma_load) {
new_args.Set(2, barrier_id);
} else {
......
......@@ -106,6 +106,35 @@ TMA_DESC_INIT_FUNC = """
\t}}
"""
TMA_IM2COL_DESC_INIT_FUNC = """
\tCUtensorMap {0};
\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
\tcuuint32_t {0}_tensorRank= {2};
\tvoid *{0}_globalAddress= {3};
\tcuuint64_t {0}_globalDim[{2}]= {{{4}}};
\tcuuint64_t {0}_globalStride[{2}]= {{{5}}};
\tcuuint32_t {0}_elementStrides[{2}]= {{{6}}};
\tint {0}_lowerCorner[{2} - 2]= {{{7}}};
\tint {0}_upperCorner[{2} - 2]= {{{8}}};
\tcuuint32_t {0}_channelsPerPixel= {9};
\tcuuint32_t {0}_pixelsPerColumn= {10};
\tCUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){11};
\tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){12};
\tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){13};
\tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){14};
\tCUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
&{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1,
{0}_lowerCorner, {0}_upperCorner, {0}_channelsPerPixel, {0}_pixelsPerColumn, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);
\tif ({0}_result != CUDA_SUCCESS) {{
\t\tstd::stringstream ss;
\t\tss << "Error: Failed to initialize the TMA descriptor {0}";
\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
\t\treturn -1;
\t}}
"""
TMA_DESC_INIT_FUNC_PY = """
\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
\t{0}_tensorRank = {2}
......@@ -401,7 +430,10 @@ class TLCUDASourceWrapper(object):
if len(args) < 3:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args
is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
dtype = self._pythonic_expr(dtype)
tensor_rank = int(self._pythonic_expr(tensor_rank))
......@@ -409,42 +441,81 @@ class TLCUDASourceWrapper(object):
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")
# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]
# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
if not is_img2col:
# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]
# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
tma_descripter_init += TMA_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(box_dim), ",".join(element_strides),
interleave, swizzle, l2Promotion, oobFill)
else:
# Calculate required length for remaining_args
expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")
# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
element_strides = [self._pythonic_expr(i) for i in element_strides]
lower_corner = [self._pythonic_expr(i) for i in lower_corner]
upper_corner = [self._pythonic_expr(i) for i in upper_corner]
# Extract remaining parameters
try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[
5 * tensor_rank - 4:5 * tensor_rank + 2]
smem_box_pixel = self._pythonic_expr(smem_box_pixel)
smem_box_channel = self._pythonic_expr(smem_box_channel)
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
) from e
tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(element_strides), ",".join(lower_corner),
",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle,
l2Promotion, oobFill)
tma_descripter_init += TMA_DESC_INIT_FUNC.format(handle_name, dtype, tensor_rank,
globalAddress, ",".join(global_dim),
",".join(global_stride),
",".join(box_dim),
",".join(element_strides), interleave,
swizzle, l2Promotion, oobFill)
return tma_descripter_init
def parse_source_information(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment