Commit 62424a74 authored by Vijay Korthikanti's avatar Vijay Korthikanti
Browse files

simd fix for upper traingular softmax kernel

parent 5e22689a
...@@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward( ...@@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1; int local_seq = blockIdx.x + 1;
...@@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward( ...@@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE; constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE; constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = 4; constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x; int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
int local_seq = blockIdx.x + 1; int local_seq = blockIdx.x + 1;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment