"src/git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "76435ca84ecd278b71b0614eb0d06838283b469e"
Unverified Commit cc00fb65 authored by Tong WU, committed by GitHub

[Enhancement] Add support for symbolic dimensions in Cython kernel adapter and improve static shape validation in wrapper (#1024)

* [Enhancement] Add support for symbolic dimensions in Cython kernel adapter and improve static shape validation in wrapper

* [BugFix] Fix shape mismatch (use 1-D buffers for the stacked expert weights and token indices) and replace `T.If()`/`T.Then()` with a plain `if` in the fused_moe example

* [Fix] Add `is_symbolic_expr` function to check for symbolic expressions in TIR

- Introduced a new utility function `is_symbolic_expr` that reports whether a TIR expression is symbolic (a plain `tvm.Var`, or a `PrimExpr` containing one), enhancing type checking.
- Updated shape handling in `CythonKernelAdapter` to use the new function, improving support for symbolic shapes.
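
A minimal sketch of the new predicate, assuming TVM is importable (as it already is in the adapter); the variable names below are illustrative only:

    from tvm import tir

    n = tir.Var("n", "int32")                    # a symbolic dimension
    is_symbolic_expr(n)                          # True: a plain tir.Var
    is_symbolic_expr(n * 2)                      # True: a PrimExpr containing a Var
    is_symbolic_expr(tir.IntImm("int32", 128))   # False: a static IntImm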
parent a79bc5c6
@@ -213,7 +213,7 @@ def moe_forward_tilelang_routed(d_hidden,
                 up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
             for i, j in T.Parallel(block_token, block_dexpert):
-                with T.If(i < actual_rows), T.Then():
+                if i < actual_rows:
                     up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]

             # Step 2: Compute down logits
@@ -261,7 +261,7 @@ def moe_forward_tilelang_routed(d_hidden,
                 transpose_B=True)
             for i, j in T.Parallel(block_token, block_dhidden):
-                with T.If(i < actual_rows), T.Then():
+                if i < actual_rows:
                     output[m_start + i, by * block_dhidden +
                            j] = output_local[i, j] * routed_expert_weights[m_start + i]
@@ -356,11 +356,11 @@ class MoE(nn.Module):
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_weights = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.float16,
             device=self.device)
         self.stacked_expert_tokens_idxs = torch.empty(
-            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], 1),
+            (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
             dtype=torch.int64,
             device=self.device)
@@ -389,7 +389,7 @@ class MoE(nn.Module):
         batch_size, seq_len, hidden_dim = x.shape
         expert_indices, expert_scores = self.gating_network(x)
         flat_expert_indices = expert_indices.view(-1)
-        flat_expert_weights = expert_scores.view(-1, 1)
+        flat_expert_weights = expert_scores.view(-1)
         x_flat = x.view(-1, hidden_dim)

         # Prepare for grouped GEMM
@@ -412,7 +412,7 @@ class MoE(nn.Module):
             expert_tokens = x_flat[exp_token_idxs]
             self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
-            self.stacked_expert_tokens_idxs[start_idx:end_idx, 0] = exp_token_idxs
+            self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
             self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[
                 idxs[start_idx:end_idx]]
...
@@ -29,6 +29,13 @@ except ImportError:
     raise


+def is_symbolic_expr(expr) -> bool:
+    """Check if the expression is a symbolic expression.
+    A symbolic expression can be a simple tvm.Var, or a tvm.PrimExpr containing tvm.Var.
+    """
+    return not isinstance(expr, tir.IntImm) and isinstance(expr, tir.PrimExpr)
+
+
 class CythonKernelAdapter(BaseKernelAdapter):
     """Adapter class that converts TVM/TIR functions to callable CUDA kernels using cython.
@@ -278,6 +285,10 @@ class CythonKernelAdapter(BaseKernelAdapter):
         for j, s in enumerate(buffer.shape):
             if isinstance(s, tir.IntImm):
                 static_shape.append((j, s.value))
+            elif is_symbolic_expr(s):
+                static_shape.append((j, -1))  # -1 for symbolic
+            else:
+                raise ValueError(f"Unsupported shape type: {type(s)}")
         for j, s in enumerate(buffer.strides):
             if isinstance(s, tir.IntImm):
                 static_strides.append((j, s.value))
...
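For illustration only (the buffer shape below is hypothetical, and the simplified loop skips the error branch): for a buffer declared with shape `(m, 1024)`, where `m` is a `tir.Var`, the loop above records the `-1` sentinel for the symbolic dimension:

    from tvm import tir

    m = tir.Var("m", "int32")
    shape = [m, tir.IntImm("int32", 1024)]   # hypothetical buffer.shape

    static_shape = [(j, s.value if isinstance(s, tir.IntImm) else -1)
                    for j, s in enumerate(shape)]
    # static_shape -> [(0, -1), (1, 1024)]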
@@ -107,9 +107,19 @@ cdef class CythonKernelWrapper:
             if not isinstance(tensor, torch.Tensor):
                 # otherwise, maybe torch.data_ptr() for T.ptr inputs
                 continue
+            # Check ndim
+            if tensor.dim() != len(shape_list):
+                raise ValueError(
+                    f"Static shape mismatch for parameter {param}: "
+                    f"expected {len(shape_list)} dimensions, "
+                    f"got {tensor.dim()}"
+                )
+            # Check each dimension
             for shape_idx, expected_shape in shape_list:
                 actual_shape = tensor.shape[shape_idx]
-                if actual_shape != expected_shape:
+                if expected_shape != -1 and actual_shape != expected_shape:
                     raise ValueError(
                         f"Static shape mismatch for parameter {param}: "
                         f"expected {expected_shape} at index {shape_idx}, "
...
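Taken together, the adapter records `-1` for symbolic dimensions and the wrapper skips them at call time while still enforcing the rank. A standalone sketch of that relaxed check, assuming only torch (the metadata and tensor below are illustrative):

    import torch

    static_shape = [(0, -1), (1, 1024)]                   # dim 0 is symbolic
    tensor = torch.empty(57, 1024, dtype=torch.float16)

    assert tensor.dim() == len(static_shape)               # ndim must still match
    for shape_idx, expected_shape in static_shape:
        # -1 marks a symbolic dimension: any runtime size is accepted there
        assert expected_shape == -1 or tensor.shape[shape_idx] == expected_shape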