"vscode:/vscode.git/clone" did not exist on "5afe7bb63a72f488922a553f29757f8da07d0d10"
Commit e9820408 authored by Woosuk Kwon's avatar Woosuk Kwon
Browse files

test


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent fdf6d72b
...@@ -665,11 +665,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s ...@@ -665,11 +665,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
// number of times random will be generated per thread, to offset philox counter in thc random // number of times random will be generated per thread, to offset philox counter in thc random
// state // state
// We use a custom RNG that increases the offset by batch_size * nheads * 32. // We use a custom RNG that increases the offset by batch_size * nheads * 32.
// int64_t counter_offset = params.b * params.h * 32; int64_t counter_offset = params.b * params.h * 32;
// auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
// auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// // Forward kernel will populate memory with the seed and offset. // Forward kernel will populate memory with the seed and offset.
// params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr()); params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
// if (p_dropout > 0.0) { // if (p_dropout > 0.0) {
// auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>( // auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment