Unverified commit f63b27e8, authored by Przemyslaw Tredak, committed by GitHub

Fix the integer overflow in fused softmax (#60)


Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
parent b67fe451
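
The offsets these kernels add to their `src`/`dst`/`mask` pointers are products of block indices, the per-warp batch index, and the row length; with 32-bit `int` arithmetic that product wraps once the tensor holds on the order of 2^31 elements, so the patch widens `first_batch` and the thread offsets to `size_t`. A minimal standalone sketch of the failure mode (illustrative shapes only, not code from the patch):

```cpp
#include <cstdio>
#include <cstddef>

int main() {
    // A shape that fused softmax can plausibly see:
    // batches * attn_heads * seq_len rows, each row of seq_len elements.
    int rows = 8 * 16 * 4096;     // roughly what first_batch can reach
    int element_count = 4096;     // row length (seq_len)

    // 32-bit arithmetic: rows * element_count = 2^31, which does not fit in int
    // (signed overflow, undefined behaviour; in practice the offset goes negative).
    int bad_offset = rows * element_count;

    // Widening one operand first makes the whole multiply 64-bit, which is what
    // declaring first_batch / thread_offset as size_t achieves in the kernels.
    size_t good_offset = static_cast<size_t>(rows) * element_count;

    std::printf("int offset: %d, size_t offset: %zu\n", bad_offset, good_offset);
    return 0;
}
```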
@@ -121,8 +121,9 @@ __global__ void scaled_softmax_warp_forward(
     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
-    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))
-                       + threadIdx.y) * WARP_BATCH;
+    size_t first_batch = (blockDim.y * (blockIdx.x + gridDim.x *
+                                        (blockIdx.y + gridDim.y * blockIdx.z))
+                          + threadIdx.y) * WARP_BATCH;
     // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
@@ -133,8 +134,9 @@ __global__ void scaled_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;
-    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
-    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    src += thread_offset;
+    dst += thread_offset;
     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
@@ -236,9 +238,10 @@ __global__ void scaled_masked_softmax_warp_forward(
     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
-    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z))
-                       + threadIdx.y) * WARP_BATCH;
-    int pad_first_batch = 0;
+    size_t first_batch = (blockDim.y * (blockIdx.x + gridDim.x *
+                                        (blockIdx.y + gridDim.y * blockIdx.z))
+                          + threadIdx.y) * WARP_BATCH;
+    size_t pad_first_batch = 0;
     if (pad_batches != 1) {  // bert style
         pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y)
                           * WARP_BATCH;
@@ -255,9 +258,11 @@ __global__ void scaled_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;
-    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
-    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
-    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    size_t thread_offset_src_dst = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    size_t thread_offset_mask = pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    src += thread_offset_src_dst;
+    dst += thread_offset_src_dst;
+    mask += thread_offset_mask;
     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
@@ -365,7 +370,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches)
-    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+    size_t first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
     // micro_batch_size might not be a multiple of WARP_BATCH. Check how
     // many batches have to computed within this WARP.
@@ -377,7 +382,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;
     // the first element to process by the current thread
-    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
+    size_t thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;
@@ -139,7 +139,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
     constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
-    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+    size_t first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH
+                         + blockIdx.x;
     int local_seq = blockIdx.x + 1;
     int warp_iteration_limit = (local_seq + ELEMENTS_PER_LDG_STG * WARP_SIZE - 1)/ WARP_SIZE;
@@ -152,8 +153,9 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     // there might be multiple batches per warp. compute the index within the batch
     int local_idx = threadIdx.x;
-    src += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
-    dst += first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    src += thread_offset;
+    dst += thread_offset;
     // load data from global memory
     acc_t elements[WARP_BATCH][WARP_ITERATIONS];
@@ -263,7 +265,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
     constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
-    int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
+    size_t first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH
+                         + blockIdx.x;
     int local_seq = blockIdx.x + 1;
     // micro_batch_size might not be a multiple of WARP_BATCH. Check how
@@ -276,7 +279,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     int local_idx = threadIdx.x;
     // the first element to process by the current thread
-    int thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
+    size_t thread_offset = first_batch * stride + ELEMENTS_PER_LDG_STG * local_idx;
     grad += thread_offset;
     output += thread_offset;
     gradInput += thread_offset;
@@ -687,6 +687,10 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input,
              (input.scalar_type() == at::ScalarType::BFloat16),
              "Only fp16 and bf16 are supported");
   AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
+  if (!input.is_contiguous())
+    input = input.contiguous();
+  if (!mask.is_contiguous())
+    mask = mask.contiguous();
   const int batches = input.size(0);
   const int pad_batches = mask.size(0);
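
The kernels index `input` and `mask` as flat, densely packed buffers, so the wrapper now copies any non-contiguous argument (for example a transposed view) into a contiguous tensor before launch. A hedged illustration of how such a view arises, using the public ATen API (not part of the patch):

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
    // Dense tensor in the (batches, attn_heads, seq_len, seq_len) layout the kernel expects.
    at::Tensor x = at::randn({2, 3, 8, 8});

    // A transpose returns a view over the same storage with permuted strides, so
    // element (b, h, i, j) is no longer at offset ((b*H + h)*S + i)*S + j.
    at::Tensor y = x.transpose(1, 2);

    std::cout << std::boolalpha
              << x.is_contiguous() << " " << y.is_contiguous() << "\n";  // true false

    // The guard added above would materialize such a view into a dense copy first.
    at::Tensor z = y.is_contiguous() ? y : y.contiguous();
    std::cout << z.is_contiguous() << "\n";                              // true
}
```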