Commit e9820408 authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

test


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent fdf6d72b
......@@ -665,11 +665,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
// int64_t counter_offset = params.b * params.h * 32;
// auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
// auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// // Forward kernel will populate memory with the seed and offset.
// params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
int64_t counter_offset = params.b * params.h * 32;
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
// if (p_dropout > 0.0) {
// auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment