Unverified Commit cae06edd authored by silentCoder-dev's avatar silentCoder-dev Committed by GitHub
Browse files

[Language]Adds a random number generation capability through curand_kernel (#1461)



* add curand.{curand_init, curand}

* run format.sh

* add default value for curand_init & add test for curand

* Update testing/python/language/test_rand.py

Remove unused thread binding
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* remove unused library

* enable tilelang cache for testing

* run format.sh

* Revert "run format.sh"

This reverts commit 5afaff782f31cdf653e2c45b469da8dead228b8a.

* Revert "enable tilelang cache for testing"

This reverts commit c277a43e77938bd88d47a108dd1bd65734d4a1ae.

* Revert "remove unused library"

This reverts commit 568ad20611f039380113937fd131151a2bffd801.

* run format.sh

* ensure FreshName for __philox_state

* ensure FreshName for __philox_state

* change the return type of T.rng_init

---------
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 48e70e68
...@@ -102,6 +102,12 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt) ...@@ -102,6 +102,12 @@ TIR_DEFINE_TL_BUILTIN(ieee_frsqrt)
TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>( TIR_DEFINE_TL_BUILTIN(ieee_fdiv).set_num_inputs(3).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kPure)); "TCallEffectKind", Integer(CallEffectKind::kPure));
// rng_init(seed, seq, off) - seeds the per-thread curand Philox RNG state.
// Registered as kOpaque: the call mutates generator state, so it must not
// be CSE'd, hoisted, or dead-code-eliminated.
TIR_DEFINE_TL_BUILTIN(rng_init).set_num_inputs(3).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
// rng_rand() - draws a 32-bit random value. Also kOpaque: every call
// advances the generator state, so two calls are not interchangeable.
TIR_DEFINE_TL_BUILTIN(rng_rand).set_num_inputs(0).set_attr<TCallEffectKind>(
    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier) TIR_DEFINE_TL_BUILTIN(create_list_of_mbarrier)
.set_num_inputs(-1) .set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind", .set_attr<TCallEffectKind>("TCallEffectKind",
......
...@@ -147,6 +147,10 @@ TVM_DLL const Op &ieee_frsqrt(); ...@@ -147,6 +147,10 @@ TVM_DLL const Op &ieee_frsqrt();
// ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division // ieee_fdiv(x, y, rounding_mode) - IEEE-compliant division
TVM_DLL const Op &ieee_fdiv(); TVM_DLL const Op &ieee_fdiv();
// Random-number-generation intrinsics (lowered to curand Philox calls):
// rng_init(seed, seq, off) seeds the per-thread state; rng_rand()
// returns a 32-bit unsigned random value from that state.
TVM_DLL const Op &rng_init();
TVM_DLL const Op &rng_rand();
/*! /*!
* \brief tvm intrinsics for TMADescriptor creation for tiled load * \brief tvm intrinsics for TMADescriptor creation for tiled load
* *
......
...@@ -297,6 +297,10 @@ std::string CodeGenTileLangCUDA::Finish() { ...@@ -297,6 +297,10 @@ std::string CodeGenTileLangCUDA::Finish() {
decl_stream << "#include <cooperative_groups.h>\n"; decl_stream << "#include <cooperative_groups.h>\n";
} }
if (need_curand_kernel_h_) {
decl_stream << "#include <curand_kernel.h>\n";
}
decl_stream << "#include <tl_templates/cuda/gemm.h>\n"; decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
if (enable_sparse_gemm_) { if (enable_sparse_gemm_) {
decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n"; decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
...@@ -2730,6 +2734,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) { ...@@ -2730,6 +2734,20 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
std::string func_name = math_func(op->dtype, "fdiv", rounding_mode); std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
os << func_name << "(" << PrintExpr(op->args[0]) << ", " os << func_name << "(" << PrintExpr(op->args[0]) << ", "
<< PrintExpr(op->args[1]) << ")"; << PrintExpr(op->args[1]) << ")";
} else if (op->op.same_as(tl::rng_init())) {
this->need_curand_kernel_h_ = true;
this->curand_philox_state = name_supply_->FreshName("__philox_state");
this->PrintIndent();
this->stream << "curandStatePhilox4_32_10_t " << this->curand_philox_state
<< ";\n";
this->PrintIndent();
this->stream << "curand_init(" << PrintExpr(op->args[0]) << ", "
<< PrintExpr(op->args[1]) << ", " << PrintExpr(op->args[2])
<< ", &" << this->curand_philox_state << ");\n";
// Store state_var for later use by rng_rand
} else if (op->op.same_as(tl::rng_rand())) {
this->need_curand_kernel_h_ = true;
os << "curand(&" << this->curand_philox_state << ")";
} else if (op->op.same_as(tl::warp_reduce_sum())) { } else if (op->op.same_as(tl::warp_reduce_sum())) {
os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")"; os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
} else if (op->op.same_as(tl::warp_reduce_max())) { } else if (op->op.same_as(tl::warp_reduce_max())) {
......
...@@ -88,6 +88,8 @@ private: ...@@ -88,6 +88,8 @@ private:
std::string vid_global_barrier_state_; std::string vid_global_barrier_state_;
// Global barrier expected node. // Global barrier expected node.
std::string vid_global_barrier_expect_; std::string vid_global_barrier_expect_;
// Global curand state
std::string curand_philox_state;
// whether enable fp16 // whether enable fp16
bool enable_fp16_{false}; bool enable_fp16_{false};
...@@ -123,6 +125,8 @@ private: ...@@ -123,6 +125,8 @@ private:
bool need_cast_smem_ptr_to_int_{false}; bool need_cast_smem_ptr_to_int_{false};
// whether need cooperative_groups.h // whether need cooperative_groups.h
bool need_cooperative_groups_{false}; bool need_cooperative_groups_{false};
// whether need curand_kernel.h
bool need_curand_kernel_h_{false};
// Op attribute map // Op attribute map
OpAttrMap<bool> op_need_warp_shuffle_ = OpAttrMap<bool> op_need_warp_shuffle_ =
Op::GetAttrMap<bool>("cuda.need_warp_shuffle"); Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
......
...@@ -1190,6 +1190,7 @@ private: ...@@ -1190,6 +1190,7 @@ private:
}); });
if ((has_non_local || has_cast_operations) && !has_reducer) { if ((has_non_local || has_cast_operations) && !has_reducer) {
DLOG(INFO) << "Try to vectorize loop";
for_node = VectorizeLoop(for_node, saved_analyzer.get()); for_node = VectorizeLoop(for_node, saved_analyzer.get());
} }
......
...@@ -152,6 +152,10 @@ private: ...@@ -152,6 +152,10 @@ private:
} else if (node->op == builtin::call_extern()) { } else if (node->op == builtin::call_extern()) {
// do not vectorize extern calls // do not vectorize extern calls
vector_size_ = 1; vector_size_ = 1;
} else if (node->op.same_as(tl::rng_rand()) ||
node->op.same_as(tl::rng_init())) {
// do not vectorize random operation
vector_size_ = 1;
} }
return arith::IRMutatorWithAnalyzer::VisitExpr_(node); return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
} }
......
import tilelang
import tilelang.language as T # noqa: N812
import torch
import triton
import triton.language as tl
@tilelang.jit
def tilelang_rand_1d(M=1024, seed=42):
    # Build a TileLang kernel that fills a length-M uint32 tensor with
    # random values produced by the T.rng_* intrinsics.
    blk_M = 128  # elements covered by each thread block
    num_threads = 128  # threads per block
    @T.prim_func
    def rand_kernel(A: T.Tensor((M,), "uint32")):
        with T.Kernel(M // blk_M, threads=num_threads) as bx:
            # Seed the RNG state; rng_init derives a per-thread sequence
            # id internally, so each thread gets an independent stream.
            T.rng_init(seed)
            # One random uint32 per element of this block's tile.
            for i in T.Parallel(blk_M):
                A[bx * blk_M + i] = T.rng_rand()
    return rand_kernel
@triton.jit
def triton_rand_1d(X, M, seed):
    # Reference Triton kernel: fills X with counter-based random integers
    # (tl.randint is a pure function of seed and offset).
    pid = tl.program_id(0)
    offset = pid * M + tl.arange(0, M)
    rand = tl.randint(seed, offset)
    # NOTE(review): the mask `offset < M` only passes for pid == 0, so this
    # is correct solely for a single-program launch grid=(1,) — confirm the
    # intended grid against the caller.
    tl.store(X + offset, rand, mask=offset < M)
if __name__ == "__main__":
    # Smoke test: compile the TileLang kernel and fill a CUDA tensor
    # with 1024 random uint32 values.
    num_elems = 1024
    rand_kernel = tilelang_rand_1d()
    out = torch.empty(num_elems, dtype=torch.uint32, device="cuda")
    rand_kernel(out)
...@@ -111,6 +111,11 @@ from .annotations import ( # noqa: F401 ...@@ -111,6 +111,11 @@ from .annotations import ( # noqa: F401
annotate_restrict_buffers, annotate_restrict_buffers,
) )
from .random import (
rng_init, # noqa: F401
rng_rand, # noqa: F401
)
def import_source(source: str | None = None): def import_source(source: str | None = None):
# source is the source code to be imported # source is the source code to be imported
......
from tvm import tir
import tilelang.language as T
# https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview
def rng_init(seed, seq=None, off=0):
    """Initialize the CUDA curand (Philox) random number generator state.

    See https://docs.nvidia.com/cuda/curand/device-api-overview.html#device-api-overview

    Parameters
    ----------
    seed : PrimExpr
        Random seed value.
    seq : PrimExpr, optional
        Sequence number for parallel random number generation. Defaults to
        the thread's linearized id (tx + bx * thread_extent) so each thread
        draws from an independent subsequence.
    off : PrimExpr, optional
        Offset within the sequence. Defaults to 0.

    Returns
    -------
    call : PrimExpr
        A void call to the ``tl.rng_init`` intrinsic. The generator state is
        materialized by codegen as a kernel-local variable; no state handle
        is returned to the caller.
    """
    seed = tir.convert(seed)
    if seq is None:
        # Default sequence id: linearize (block, thread) bindings so every
        # thread seeds a distinct Philox subsequence. Avoid the name `id`,
        # which shadows the Python builtin.
        bx = T.get_block_binding()
        ex = T.kernel.get_thread_extent()
        tx = T.get_thread_binding()
        linear_tid = tx + bx * ex
        seq = tir.convert(linear_tid)
    else:
        seq = tir.convert(seq)
    off = tir.convert(off)
    return tir.call_intrin("void", tir.op.Op.get("tl.rng_init"), seed, seq, off)
def rng_rand():
    """Draw one 32-bit unsigned random integer from the curand state.

    Returns
    -------
    random_value : PrimExpr
        A ``uint32`` call to the ``tl.rng_rand`` intrinsic.
    """
    rand_op = tir.op.Op.get("tl.rng_rand")
    return tir.call_intrin("uint32", rand_op)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment