"...text-generation-inference.git" did not exist on "54fec9319371b2792526e0cbfebe6cee66ed3980"
Unverified Commit f0672603 authored by silentCoder-dev's avatar silentCoder-dev Committed by GitHub
Browse files

[Refactor] Rename test for curand & add triton baseline in `test_tilelang_language_rand.py` (#1464)

* rename test for curand & add triton baseline

* add a comment for calling T.rng_rand() four times

* refactor tilelang&triton kernel

* Add boundary checks for M not divisible by 128
parent 7248a810
import tilelang
import tilelang.language as T # noqa: N812
import torch
import triton
import triton.language as tl
@tilelang.jit
def tilelang_rand_1d(M=1024, seed=42):
    """Build a TileLang kernel that fills a length-M uint32 tensor with random values.

    Args:
        M: number of elements to generate (default 1024). No longer required
            to be a multiple of the 128-element block size.
        seed: seed forwarded to T.rng_init (default 42).

    Returns:
        A compiled TileLang prim_func taking one uint32 tensor of shape (M,).
    """
    blk_M = 128
    num_threads = 128

    @T.prim_func
    def rand_kernel(A: T.Tensor((M,), "uint32")):
        # ceildiv grid + per-element guard so a partial final block is handled
        # when M is not divisible by blk_M (previously M // blk_M silently
        # dropped the tail).
        with T.Kernel(T.ceildiv(M, blk_M), threads=num_threads) as bx:
            T.rng_init(seed)
            for i in T.Parallel(blk_M):
                idx = bx * blk_M + i
                if idx < M:
                    A[idx] = T.rng_rand()

    return rand_kernel
@triton.jit
def triton_rand_1d(X, M, seed):
    """Triton baseline: fill X with M random uint32 values from tl.randint."""
    pid = tl.program_id(0)
    # NOTE(review): offset = pid * M means any program with pid > 0 produces
    # offsets >= M, which the store mask discards — this appears to assume a
    # single-program launch (grid of 1); confirm against the caller.
    offset = pid * M + tl.arange(0, M)
    rand = tl.randint(seed, offset)
    # Mask guards the tail so out-of-range lanes never write.
    tl.store(X + offset, rand, mask=offset < M)
if __name__ == "__main__":
    # Smoke test: generate 1024 random uint32 values on the GPU.
    num_elems = 1024
    rng_kernel = tilelang_rand_1d()
    out = torch.empty(num_elems, dtype=torch.uint32, device="cuda")
    rng_kernel(out)
import tilelang
import tilelang.language as T
import torch
import triton
import triton.language as tl
import pytest
import tilelang.testing
@tilelang.jit
def tilelang_rand_1d(M=1024, seed=42):
    """Build a TileLang kernel filling a length-M uint32 tensor with random values.

    Each thread generates `num_per_thread` consecutive values; the grid is
    sized with ceildiv and writes are guarded, so M need not be a multiple
    of the per-block element count.

    Args:
        M: number of elements to generate.
        seed: seed forwarded to T.rng_init.

    Returns:
        A compiled TileLang prim_func taking one uint32 tensor of shape (M,).
    """
    num_per_thread = 128
    threads = 1
    blk_M = num_per_thread * threads  # elements covered by one block

    @T.prim_func
    def rand_kernel(A: T.Tensor((M,), "uint32")):
        with T.Kernel(T.ceildiv(M, threads * num_per_thread), threads=threads) as bx:
            tx = T.get_thread_binding()
            # Offset each thread's RNG stream by its starting element index so
            # threads draw from disjoint positions of the sequence.
            T.rng_init(seed, 0, bx * blk_M + tx * num_per_thread)
            for i, j in T.Parallel(threads, num_per_thread):
                offsets = (bx * threads + i) * num_per_thread
                idx = offsets + j
                # Boundary guard: skip writes past M when M % blk_M != 0.
                if idx < M:
                    A[idx] = T.rng_rand()

    return rand_kernel
@triton.jit
def triton_rand_1d(X, M, elements_per_thread, seed):
    """Triton baseline: each program draws `elements_per_thread` RNG counters,
    and tl.randint4x expands each counter into four uint32 outputs."""
    pid = tl.program_id(0)
    offset = pid * elements_per_thread + tl.arange(0, elements_per_thread)
    r0, r1, r2, r3 = tl.randint4x(seed, offset)
    # The four outputs of counter `offset` land at consecutive indices
    # 4*offset .. 4*offset+3, matching the TileLang kernel's layout.
    base_idx = offset * 4
    # Per-lane masks give boundary safety when M is not a multiple of 4.
    tl.store(X + base_idx, r0, mask=base_idx < M)
    tl.store(X + base_idx + 1, r1, mask=(base_idx + 1) < M)
    tl.store(X + base_idx + 2, r2, mask=(base_idx + 2) < M)
    tl.store(X + base_idx + 3, r3, mask=(base_idx + 3) < M)
@tilelang.testing.requires_cuda
@pytest.mark.parametrize("M, seed", [(1024, 42), (512, 123), (128, 0)])
def test_rand_1d(M, seed):
    """Check the TileLang RNG kernel matches the Triton randint4x baseline.

    Args:
        M: number of uint32 values to generate.
        seed: RNG seed shared by both kernels.
    """
    kernel = tilelang_rand_1d(M, seed)
    tilelang_result = torch.empty(M, dtype=torch.uint32, device="cuda")
    kernel(tilelang_result)

    triton_result = torch.empty(M, dtype=torch.uint32, device="cuda")
    # Use ceildiv rather than floor division so a partial final program is
    # still launched when M is not a multiple of 128; the Triton kernel's
    # store masks handle the tail. (M // 128 would skip the tail entirely,
    # or launch an empty grid for M < 128.)
    grid = (triton.cdiv(M, 128),)
    # Each program covers 128 elements: 32 counters x 4 outputs per counter.
    triton_rand_1d[grid](triton_result, tl.constexpr(M), tl.constexpr(128 // 4), seed)
    torch.testing.assert_close(tilelang_result, triton_result)
if __name__ == "__main__":
    # Delegate to tilelang's test entry point when executed as a script.
    tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment