Commit a91bc2a9 authored by Lei Wang, committed by LeiWang1999

[Refactor] Update barrier functions and add new example for GEMM with warp specialization (#456)

* Add example for warp specialization with flash attention

* Introduced a new example script `example_warp_specialize_flashmla.py` demonstrating flash attention using warp specialization in TileLang.
* Implemented the `flashattn` function with shared memory allocation and memory barrier synchronization for improved performance; a simplified sketch of the producer/consumer structure follows this list.
* Added a reference program for validation against PyTorch's implementation, including profiling for latency and performance metrics.
* Removed the outdated `example_warp_specialize_mla.py` to streamline examples and focus on the new implementation.
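
For orientation, here is a minimal, non-runnable sketch of the producer/consumer split used in the new example. The producer warp-group index and the omitted allocation/softmax steps are assumptions; the barrier calls and the consumer group mirror the diff further down.

```python
# Simplified sketch; buffer allocation, softmax, and the epilogue are omitted.
with T.ws(2):                                     # producer warp group (index assumed)
    for k in T.serial(loop_range):
        # wait until the consumers have released the buffer for this iteration
        T.barrier_wait(barrier_id=(k % 1) + 2, parity=(k % 2) ^ 1)
        T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
        T.barrier_arrive(k % 1)                   # signal that KV_shared is ready

with T.ws(0, 1):                                  # consumer warp groups
    for k in T.serial(loop_range):
        T.barrier_wait(barrier_id=k % 1, parity=(k // 1) % 2)
        T.gemm(Q_shared, KV_shared, acc_s, transpose_B=True)
        T.barrier_arrive(barrier_id=k % 1 + 2)    # release KV_shared back to the producer
```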

* Add memory barrier functions to builtin.py

* Introduced `barrier_wait` and `barrier_arrive` functions for memory barrier synchronization.
* Enhanced documentation with detailed docstrings for both functions, clarifying their usage and parameters.
* The `barrier_wait` function is a thin wrapper around `mbarrier_wait_parity` and supports parity values 0 and 1; a short usage sketch follows this list.
* Improved code organization and readability by adding blank lines for better separation of logical sections.
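
A short usage sketch of the two helpers, assuming they are called inside a `T.Kernel` body as in the diff below; the barrier ids and the loop are illustrative only.

```python
import tilelang.language as T

# barrier_arrive(barrier_id)       -> thin wrapper over mbarrier_arrive
# barrier_wait(barrier_id, parity) -> thin wrapper over mbarrier_wait_parity (parity 0 or 1)

# inside a T.Kernel(...) body (illustrative barrier ids):
for k in T.serial(num_stages):
    T.barrier_wait(barrier_id=0, parity=k % 2)   # block until the producer has filled the buffer
    # ... consume the shared-memory tile ...
    T.barrier_arrive(barrier_id=1)               # allow the producer to refill it
```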

* Enhance code readability by adding blank lines in example_warp_specialize_flashmla.py and builtin.py

* Added blank lines to improve code organization and separation of logical sections in `example_warp_specialize_flashmla.py`.
* Included blank lines in `builtin.py` around the `wait_wgmma` and `barrier_wait` functions for better readability.

* [Refactor] Update barrier functions and add new example for GEMM with warp specialization

* Refactored memory barrier functions in `example_warp_specialize_flashmla.py` to use the new `barrier_wait` and `barrier_arrive` methods for improved clarity and consistency.
* Introduced a new example script `example_warp_specialize_gemm_copy_gemm_0_1.py` demonstrating matrix multiplication with warp specialization and shared memory allocation.
* Enhanced the `layout.cc` and `elem.cc` files to improve structural equality checks and error handling in copy operations.
* Updated `warpgroup.py` to refine thread ID calculations for better performance in warp specialization scenarios.
* Added new shuffle operations in `builtin.py` for enhanced functionality in parallel computations.
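
A hedged sketch of the new shuffle helpers, assuming they are re-exported under `tilelang.language` as `T.shfl_*`; each wrapper lowers to the corresponding CUDA warp intrinsic with a full `0xffffffff` mask, e.g. `T.shfl_xor(v, off) -> __shfl_xor_sync(0xffffffff, v, off)`.

```python
import tilelang.language as T

def warp_sum(val):
    # Illustrative butterfly reduction over the 32 lanes of a warp; after the
    # loop every lane holds the sum of `val` across the warp.
    for offset in (16, 8, 4, 2, 1):
        val = val + T.shfl_xor(val, offset)
    return val
```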

* lint fix

* Update loop variable checks in SIMT loop and buffer region validation

* Modified the checks in `elem.cc` so that the number of loop variables must not exceed the source and destination range sizes, with more informative error messages.
* Adjusted the assertions in `copy.py` to match, allowing more flexible region extent comparisons and improving the error messages; the matching rule is restated in the sketch below.
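
The relaxed extent check added to `buffer_region_to_tile_region` in `copy.py` can be restated in plain Python. `match_extents` below is a hypothetical helper written only to illustrate the rule (using ints instead of PrimExprs); it is not part of the patch.

```python
def match_extents(region_extents, extents):
    """Illustrative restatement of the new rule (ints instead of PrimExprs)."""
    assert len(region_extents) >= len(extents)
    out = list(region_extents)
    remaining = list(extents)
    # Pass 1: if every requested extent is already present in the region
    # (in any order), the region extents are used as-is; extra dims must be 1.
    for v in out:
        if v in remaining:
            remaining.remove(v)
        elif v != 1:
            raise ValueError(f"incompatible extents: {region_extents} vs {extents}")
    if remaining:
        # Pass 2: align from the last dimension; only size-1 region dims
        # may be widened to the requested extent.
        for i, e in enumerate(extents):
            idx = len(out) - len(extents) + i
            if out[idx] != e:
                if out[idx] == 1:
                    out[idx] = e
                else:
                    raise ValueError(f"incompatible extents: {region_extents} vs {extents}")
    return out

print(match_extents([128, 1], [128, 64]))    # -> [128, 64]
print(match_extents([1, 64, 32], [64, 32]))  # -> [1, 64, 32] (already covered)
```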

* lint fix

* test fix
parent 8adfc117
......@@ -44,7 +44,6 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.create_list_of_mbarrier(128, 128, 256, 128)
loop_range = T.ceildiv(seqlen_kv, block_N)
......@@ -52,32 +51,29 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.dec_max_nreg(24)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.mbarrier_arrive(T.get_mbarrier(3))
T.barrier_arrive(barrier_id=3)
for k in T.serial(loop_range):
T.mbarrier_wait_parity(
T.FloorMod(k, 1) + 2, T.bitwise_xor(T.FloorDiv(k, 1) % 2, 1))
T.barrier_wait(barrier_id=(k % 1) + 2, parity=(k % 2) ^ 1)
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
T.mbarrier_arrive(T.FloorMod(k, 1))
T.barrier_arrive(k % 1)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
T.mbarrier_arrive(T.FloorMod(k, 1) + 1)
T.barrier_arrive(k % 1 + 1)
with T.ws(0, 1):
T.inc_max_nreg(240)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
T.mbarrier_wait_parity(T.get_mbarrier(3), 0)
T.barrier_wait(3, 0)
for k in T.serial(loop_range):
T.clear(acc_s)
T.mbarrier_wait_parity(T.get_mbarrier(T.FloorMod(k, 1)), T.FloorDiv(k, 1) % 2)
T.barrier_wait(barrier_id=k % 1, parity=(k // 1) % 2)
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.mbarrier_wait_parity(
T.get_mbarrier(T.FloorMod(k, 1) + 1),
T.FloorDiv(k, 1) % 2)
T.barrier_wait(barrier_id=k % 1 + 1, parity=(k // 1) % 2)
T.gemm(
Q_pe_shared,
K_pe_shared,
......@@ -98,7 +94,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
T.mbarrier_arrive(T.get_mbarrier(T.FloorMod(k, 1) + 2))
T.barrier_arrive(barrier_id=k % 1 + 2)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
......@@ -181,6 +177,7 @@ def main():
program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
print(kernel.get_kernel_source())
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
......
import tilelang
import tilelang.language as T
tilelang.disable_cache()
def matmul_warp_specialize_copy_1_gemm_0(M,
N,
K,
block_M,
block_N,
block_K,
dtype="float16",
accum_dtype="float"):
warp_group_num = 2
threads = 128 * warp_group_num
# add decorator @tilelang.jit if you want to return a torch function
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Initialize Kernel Context
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
B_shared_g0 = T.alloc_shared((block_K, block_N // warp_group_num), dtype, "shared")
B_shared_g1 = T.alloc_shared((block_K, block_N // warp_group_num), dtype, "shared")
C_local_g0 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype)
C_local_g1 = T.alloc_fragment((block_M, block_N // warp_group_num), accum_dtype)
T.clear(C_local_g0)
T.clear(C_local_g1)
for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
T.copy(A[by * block_M, ko * block_K], A_shared)
with T.ws(1):
T.copy(B[ko * block_K, bx * block_N], B_shared_g1)
T.gemm(A_shared, B_shared_g1, C_local_g1)
with T.ws(0):
T.copy(B[ko * block_K, bx * block_N + block_N // warp_group_num], B_shared_g0)
T.gemm(A_shared, B_shared_g0, C_local_g0)
T.copy(C_local_g1, C[by * block_M, bx * block_N])
T.copy(C_local_g0, C[by * block_M, bx * block_N + block_N // warp_group_num])
return main
def main():
M = 128
N = 128
K = 64
block_M = 128
block_N = 128
block_K = 64
# 1. Define the kernel (matmul) and compile/lower it into an executable module
func = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
# print(func.script())
# 2. Compile the kernel into a torch function
# out_idx specifies the index of the output buffer in the argument list
# if out_idx is specified, the tensor will be created during runtime
# target currently can be "cuda" or "hip" or "cpu".
jit_kernel = tilelang.compile(
func,
out_idx=[2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
# tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
print(jit_kernel.get_kernel_source())
# 3. Test the kernel in Python with PyTorch data
import torch
# Create random input tensors on the GPU
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
# Run the kernel through the Profiler
c = jit_kernel(a, b)
print(c)
# Reference multiplication using PyTorch
ref_c = a @ b
print(ref_c)
# Validate correctness
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("Kernel output matches PyTorch reference.")
# 4. Retrieve and inspect the generated CUDA source (optional)
# cuda_source = jit_kernel.get_kernel_source()
# print("Generated CUDA kernel:\n", cuda_source)
# 5. Profile latency with the kernel
profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
latency = profiler.do_bench()
print(f"Latency: {latency} ms")
if __name__ == "__main__":
main()
......@@ -415,6 +415,7 @@ bool FragmentNode::IsEqual(const FragmentNode *other, bool skip_index) const {
// a[i, j] = b[j, i] in register level.
bool ret = StructuralEqual()(this->InputShape(), other->InputShape());
ret &= StructuralEqual()(this->ThreadRange(), other->ThreadRange());
if (!ret) {
// may be broadcast case
return true;
......
......@@ -115,6 +115,16 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
for (const auto &iv : loop_vars)
analyzer->Bind(iv->var, iv->dom);
ICHECK(loop_vars.size() <= src_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", src_range.size() = " << src_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
ICHECK(loop_vars.size() <= dst_range.size())
<< "loop_vars.size() = " << loop_vars.size()
<< ", dst_range.size() = " << dst_range.size() << ", src = " << src->name
<< ", dst = " << dst->name;
Array<PrimExpr> src_indices = MakeIndices(loop_vars, 0);
Array<PrimExpr> dst_indices = MakeIndices(loop_vars, 1);
......
......@@ -146,7 +146,6 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (level == InferLevel::kStrict)
return {};
auto block_size = T.thread_bounds->extent;
// Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer;
for (const auto &[buffer, indices] : indice_map_) {
......@@ -227,14 +226,28 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
}
loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
if (!analyzer_.CanProveEqual(loop_thread_extent, block_size))
AddPredicate(
LT(InputPlaceholder(0) - T.thread_bounds->min, loop_thread_extent));
} else {
return {};
}
PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();
auto block_size = T.thread_bounds->extent;
if (loop_layout_.defined()) {
if (loop_layout_->ThreadRange().defined()) {
auto thread_range = loop_layout_->ThreadRange();
block_size = thread_range->extent;
AddPredicate(GE(InputPlaceholder(0), thread_range->min));
AddPredicate(
LT(InputPlaceholder(0), thread_range->min + thread_range->extent));
}
}
if (!analyzer_.CanProveEqual(loop_thread_extent, block_size)) {
AddPredicate(
LT(InputPlaceholder(0), loop_thread_extent + T.thread_bounds->min));
}
// Step 2: Check that the loop's partition can correctly align with all source
// fragment
for (const auto &[buffer, _] : indice_map_) {
......
......@@ -1019,7 +1019,7 @@ public:
// Check if function only uses threadIdx.x before proceeding
if (!ThreadTagChecker::HasOnlyThreadIdxX(f)) {
LOG(WARNING) << "WarpSpecialize will be disabled because the program "
"uses thread tags other than threadIdx.x\n"
"uses thread tags other than threadIdx.x."
<< "If you want to use warp specialization, please refactor "
"your program to use threadIdx.x only";
// Return original function unchanged if other thread tags are found
......@@ -1190,12 +1190,14 @@ public:
static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
WarpSpecializedDetector detector;
detector.VisitStmt(stmt);
return detector.has_tma_op_ && detector.has_mbarrier_op_;
return detector.has_warp_specialization_ ||
(detector.has_tma_op_ && detector.has_mbarrier_op_);
}
WarpSpecializedDetector() {
has_tma_op_ = false;
has_mbarrier_op_ = false;
has_warp_specialization_ = false;
}
private:
......@@ -1219,8 +1221,58 @@ private:
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitStmt_(const IfThenElseNode *op) final {
// do not visit the body of the if-then-else statement
// because we only care about the condition
auto cond = op->condition;
// assert cond is a binary expression
PostOrderVisit(cond, [this](const ObjectRef &node) {
bool is_cmp_op = false;
if (const auto *lt = node.as<LTNode>()) {
is_cmp_op = true;
} else if (const auto *le = node.as<LENode>()) {
is_cmp_op = true;
} else if (const auto *gt = node.as<GTNode>()) {
is_cmp_op = true;
} else if (const auto *ge = node.as<GENode>()) {
is_cmp_op = true;
}
if (is_cmp_op) {
bool has_thread_var = false;
bool has_warp_group_size = false;
// check if has thread_var_ in lt->a or lt->b
PostOrderVisit(node, [this, &has_thread_var,
&has_warp_group_size](const ObjectRef &node_) {
if (node_.as<VarNode>() == thread_var_->var.get()) {
has_thread_var = true;
} else if (const auto *imm = node_.as<IntImmNode>()) {
// 128 is the warp-group size on NVIDIA GPUs
has_warp_group_size = imm->value % 128 == 0;
}
});
if (has_thread_var && has_warp_group_size) {
has_warp_specialization_ = true;
}
}
});
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
bool has_tma_op_{false};
IterVar thread_var_;
bool has_mbarrier_op_{false};
bool has_warp_specialization_{false};
};
using namespace tir::transform;
......
......@@ -336,6 +336,7 @@ def run_gemm(
@tvm.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
def tl_matmul_with_ladder_weight_only_transform_block_reduce_int4(
M,
N,
......@@ -630,6 +631,7 @@ def test_run_dequantize_gemm():
@tilelang.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
def test_assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4():
assert_tl_matmul_with_ladder_weight_only_transform_block_reduce_int4_correctness(
256, 1024, 512, "float16", "float16", "float16", 3)
......
......@@ -13,7 +13,7 @@ from tilelang.intrinsics.mma_macro_generator import (
)
from tilelang.transform import simplify_prim_func
tilelang.testing.set_random_seed(0)
tilelang.testing.set_random_seed(42)
@simplify_prim_func
......@@ -394,10 +394,10 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
@tilelang.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
def test_assert_tl_matmul_weight_only_transform():
assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32")
if __name__ == "__main__":
# tilelang.testing.main()
test_assert_tl_matmul_weight_only_transform()
tilelang.testing.main()
......@@ -220,7 +220,7 @@ class TLCUDASourceWrapper(object):
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
function_name, grid_str, block_str, smem_str, call_args)
kernel_launch_code += "TILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
......
......@@ -233,3 +233,62 @@ def wait_wgmma(*args):
tir.Call: A handle to the WGMMA wait operation
"""
return tir.call_intrin("handle", tir.op.Op.get("tl.wait_wgmma"), *args)
def barrier_wait(barrier_id: Union[int, PrimExpr, tir.Call], parity: Union[int, Var, None] = None):
"""Wait for a memory barrier to complete.
Args:
barrier_id: Optional[int, PrimExpr]
The memory barrier to wait on
parity: Optional[int, Var]
The parity value to wait for
Returns:
tir.Call: A handle to the barrier wait operation
The current implementation is syntactic sugar for mbarrier_wait_parity; only parity values 0 and 1 are supported.
"""
return mbarrier_wait_parity(barrier_id, parity)
def barrier_arrive(barrier_id: Union[int, PrimExpr, tir.Call]):
"""Arrive at a memory barrier.
Args:
barrier_id: Optional[int, PrimExpr]
The memory barrier to arrive at
"""
return mbarrier_arrive(barrier_id)
def shfl_xor(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
"""Perform a shuffle operation with XOR offset.
Args:
value: Optional[int, PrimExpr]
The value to shuffle
offset: Optional[int, PrimExpr]
The offset for the shuffle operation
Returns:
tir.Call: A handle to the shuffle operation
"""
return tir.call_extern(value.dtype, "__shfl_xor_sync", 0xffffffff, value, offset)
def shfl_down(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
"""Perform a shuffle operation with down offset.
Args:
value: Optional[int, PrimExpr]
The value to shuffle
"""
return tir.call_extern(value.dtype, "__shfl_down_sync", 0xffffffff, value, offset)
def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr, tir.Call]):
"""Perform a shuffle operation with up offset.
Args:
value: Optional[int, PrimExpr]
The value to shuffle
"""
return tir.call_extern(value.dtype, "__shfl_up_sync", 0xffffffff, value, offset)
......@@ -60,7 +60,8 @@ def buffer_load_to_tile_region(load: tir.BufferLoad, access_type: str, extents:
return region(load, access_type, *extents)
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str):
def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: str,
extents: List[tir.PrimExpr]):
"""Convert a buffer region to a tile region descriptor.
Args:
......@@ -71,8 +72,34 @@ def buffer_region_to_tile_region(buffer_region: tir.BufferRegion, access_type: s
tir.Call: A region descriptor for the specified buffer region
"""
mins = [x.min for x in buffer_region.region]
extents = [x.extent for x in buffer_region.region]
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *extents)
region_extents = [x.extent for x in buffer_region.region]
assert len(region_extents) >= len(
extents), f"region_extents = {region_extents}, extents = {extents}"
# If region_extents already contains all elements
# of extents (in any order), pass directly
tmp_extents = list(extents)
for i in range(len(region_extents)):
v = region_extents[i]
if v in tmp_extents:
tmp_extents.remove(v)
elif v != 1:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{i}] = {v}, extents[{i}] = {extents[i]}"
)
if len(tmp_extents) > 0:
# Otherwise, align extents from the last dimension, region_extents
# can only replace 1 with extents value, otherwise raise error
for i in range(len(extents)):
idx = len(region_extents) - len(extents) + i
if region_extents[idx] != extents[i]:
if region_extents[idx] == 1:
region_extents[idx] = extents[i]
else:
raise ValueError(
f"buffer {buffer_region.buffer} region_extents[{idx}] = {region_extents[idx]}, extents[{i}] = {extents[i]}"
)
return region(T.BufferLoad(buffer_region.buffer, mins), access_type, *region_extents)
def copy(
......@@ -108,13 +135,10 @@ def copy(
src_extent = get_extent(src)
dst_extent = get_extent(dst)
if src_extent:
extent = src_extent
elif dst_extent:
extent = dst_extent
else:
raise TypeError("Can't deduce copy extents from args")
assert src_extent or dst_extent, "Can't deduce copy extents from args"
src_extent = list(src_extent) if src_extent else [1] * len(dst_extent)
dst_extent = list(dst_extent) if dst_extent else [1] * len(src_extent)
extent = max(src_extent, dst_extent)
def _to_region(data, access_type):
if isinstance(data, tir.Var) and T.has_let_value(data):
......@@ -122,7 +146,7 @@ def copy(
if isinstance(data, tir.Buffer):
return buffer_to_tile_region(data, access_type)
elif isinstance(data, tir.BufferRegion):
return buffer_region_to_tile_region(data, access_type)
return buffer_region_to_tile_region(data, access_type, extent)
else:
return buffer_load_to_tile_region(data, access_type, extent)
......
......@@ -42,6 +42,24 @@ def print_var_with_condition(condition: tir.PrimExpr,
tir.call_extern("handle", "debug_print_var", msg, var)
@macro
def print_global_buffer_with_condition(condition: tir.PrimExpr,
buffer: tir.Buffer,
elems: int,
msg: str = "") -> tir.PrimExpr:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
"""
if condition:
# Iterate through the buffer elements and print each one.
for i in serial(elems):
coords = index_to_coordinates(i, buffer.shape)
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
buffer[coords])
else:
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
@macro
def print_shared_buffer_with_condition(condition: tir.PrimExpr,
buffer: tir.Buffer,
......@@ -170,6 +188,15 @@ def print(obj: Any, msg: str = "") -> tir.PrimExpr:
if not msg:
msg = f"buffer<{buffer.name}, {buffer.dtype}>"
return print_shared_buffer_with_condition(condition, buffer, elems, msg)
elif buffer.scope() == "global":
# Get the number of elements in the buffer.
elems = 1
for dim in buffer.shape:
elems *= dim
condition = True
return print_global_buffer_with_condition(condition, buffer, elems, msg)
else:
raise ValueError(f"Unsupported buffer scope: {buffer.scope()}")
elif isinstance(obj, tir.PrimExpr):
if not msg:
......
......@@ -35,8 +35,13 @@ def WarpSpecialize(*warp_group_idx):
>>> T.ws(0, 1) -> if tx < 128 or (tx >= 128 and tx < 256)
"""
id_x, id_y, id_z = get_thread_bindings()
ex_x, ex_y, _ = get_thread_extents()
tid = id_z * (ex_y * ex_x) + id_y * ex_x + id_x
ex_x, ex_y, ex_z = get_thread_extents()
tid = id_x
if ex_y > 1:
tid = id_y * ex_x + tid
if ex_z > 1:
tid = id_z * (ex_y * ex_x) + tid
# Only available on NVIDIA GPUs.
warp_group_size = 128
......