Unverified Commit eefb1ba2 authored by Burc Eryilmaz, committed by GitHub

cuda rng changes for graph capture with apex MHA (#1025)


Co-authored-by: Sukru Eryilmaz <seryilmaz@computelab-dgx1v-32.nvidia.com>
parent 154c6336
 #pragma once
-//#include <ATen/ATen.h>
-#ifdef OLD_GENERATOR
-#include <ATen/CUDAGenerator.h>
-#else
 #include <ATen/CUDAGeneratorImpl.h>
-#endif
+#include <ATen/cuda/CUDAGraphsUtils.cuh>
-//#include <ATen/cuda/CUDAContext.h>
 #include <curand_kernel.h>
 #include "philox.h"
-//#include <THC/THCGeneral.h>
 #include <assert.h>
 #include <cfloat>
@@ -279,7 +271,7 @@ bool dispatch_softmax(output_t *dst, const input_t *src, int softmax_elements, i
 }
 template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
-__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p)
+__global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)
 {
     assert(ELEMENTS_PER_LDG_STG==4);
@@ -386,8 +378,8 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
             sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
         }
     }
+    auto seeds = at::cuda::philox::unpack(philox_args);
-    Philox ph(seeds.first, tid, seeds.second);
+    Philox ph(std::get<0>(seeds), tid, std::get<1>(seeds));
     uint8_t rands[WARP_BATCH][WARP_ITERATIONS];
     float4 rand_num;
     #pragma unroll
@@ -434,7 +426,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
 template <typename input_t, typename output_t, typename acc_t, int WARP_BATCH, int WARP_ITERATIONS, int WARP_SIZE, int ELEMENTS_PER_LDG_STG>
-__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p)
+__global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p)
 {
     assert(ELEMENTS_PER_LDG_STG==1);
     int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
@@ -541,10 +533,11 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
         }
     }
     curandStatePhilox4_32_10_t state;
+    auto seeds = at::cuda::philox::unpack(philox_args);
     curand_init(
-        seeds.first,
+        std::get<0>(seeds),
         tid,
-        seeds.second,
+        std::get<1>(seeds),
         &state);
     // store result
@@ -585,7 +578,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
 // WARP_SIZE number of elements working on a single batch, has to be a power of two.
 // ELEMENTS_PER_LDG_STG has to be 1.
 template <typename input_t, typename output_t, typename acc_t>
-using additive_masked_softmax_dropout_forward_func = void(*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, std::pair<uint64_t,uint64_t> seeds, float p);
+using additive_masked_softmax_dropout_forward_func = void(*)(output_t *dst, uint8_t *dropout_mask, const input_t *src, const input_t *pad_mask, int batch_size, int stride, int element_count, int pad_batch_stride, at::PhiloxCudaState philox_args, float p);
 template <typename input_t, typename output_t, typename acc_t>
@@ -622,7 +615,6 @@ bool warp_additive_masked_softmax_dropout_kernel(int element_count, int log2_ele
         case 7: // 128
             if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 2,4,32,4>;
             else kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
-            //kernel = &additive_masked_softmax_dropout_warp_forward<input_t, output_t, acc_t, 2,4,32,1>;
             break;
         case 8: // 256
             if (flag_vec4) kernel = &additive_masked_softmax_dropout_warp_forward_vec4<input_t, output_t, acc_t, 1,8,32,4>;
@@ -671,19 +663,13 @@ bool dispatch_additive_masked_softmax_dropout(output_t *dst, uint8_t *dropout_ma
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
    int blocks = (batch_count + batches_per_block - 1) / batches_per_block;
+    c10::optional<at::Generator> gen_;
-    auto gen = at::cuda::detail::getDefaultCUDAGenerator();
+    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
     int64_t counter_offset = (totalElements/(blocks*threads_per_block)+1);
-    std::pair<uint64_t, uint64_t> rng_engine_inputs;
+    at::PhiloxCudaState rng_engine_inputs;
     {
-        // See Note [Acquire lock when using random generators]
-#ifdef OLD_GENERATOR
         std::lock_guard<std::mutex> lock(gen->mutex_);
-        rng_engine_inputs = gen->philox_engine_inputs(counter_offset);
+        rng_engine_inputs = gen->philox_cuda_state(counter_offset);
-#else
-        std::lock_guard<std::mutex> lock(gen.mutex());
-        rng_engine_inputs = at::check_generator<at::CUDAGeneratorImpl>(gen)->philox_engine_inputs(counter_offset);
-#endif
     }
     // compute launch size
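
In short, the commit stops passing a host-side std::pair<uint64_t,uint64_t> of Philox seed/offset into the kernels and instead passes at::PhiloxCudaState, which each kernel unpacks on the device with at::cuda::philox::unpack. Because the seed and offset can then be resolved at kernel run time (from device memory when the generator is in graph-capture mode), the launch can be captured and replayed in a CUDA graph with correct RNG advancement. The sketch below shows that pattern in isolation; the kernel, launcher, and sizes are illustrative assumptions and are not the apex kernels from this diff.

// Hypothetical sketch (not apex code): consumes at::PhiloxCudaState the same
// way the kernels in this diff do.
#include <mutex>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h>

__global__ void dropout_sketch(float *dst, const float *src, int n,
                               at::PhiloxCudaState philox_args, float p)
{
    // unpack() resolves the seed/offset either from the values captured at
    // launch or, under CUDA graph capture, through device-memory pointers.
    auto seeds = at::cuda::philox::unpack(philox_args);
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    curandStatePhilox4_32_10_t state;
    curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
    if (tid < n) {
        float r = curand_uniform(&state);
        dst[tid] = (r < p) ? 0.f : src[tid] / (1.f - p);
    }
}

void launch_dropout_sketch(float *dst, const float *src, int n, float p,
                           cudaStream_t stream)
{
    // Host side mirrors the dispatch change above: take the generator lock,
    // reserve counter space with philox_cuda_state(), pass the state by value.
    c10::optional<at::Generator> gen_;
    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
        gen_, at::cuda::detail::getDefaultCUDAGenerator());
    at::PhiloxCudaState rng_engine_inputs;
    {
        std::lock_guard<std::mutex> lock(gen->mutex_);
        rng_engine_inputs = gen->philox_cuda_state(1);  // one 4x32 draw per thread (assumed here)
    }
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    dropout_sketch<<<blocks, threads, 0, stream>>>(dst, src, n, rng_engine_inputs, p);
}

The point of the change is that, during graph capture, the generator hands out device-side pointers rather than immediate seed/offset integers, so a kernel that wants to be capturable must call unpack() at run time instead of receiving a precomputed std::pair from the host.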