Unverified Commit d9eb1991 authored by Oleg Goncharov and committed by GitHub

[common] Added new unfused softmax cuda kernel to support causal attention mask (#652)



* Added new unfused softmax cuda kernel to support causal attention mask
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Added test suite for unfused causal softmax kernel
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Removed test cases with large matrices from the causal softmax test suite
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Cleaned up the code per lint
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Added a compute buffer to the causal softmax test suite to store intermediate results without casting
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Added more test cases
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Relaxed absolute tolerance atol
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Relaxed absolute tolerance for BF16
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
parent 94de051f
@@ -15,6 +15,7 @@ add_executable(test_operator
test_layernorm.cu
test_rmsnorm.cu
test_multi_cast_transpose.cu
+ test_causal_softmax.cu
../test_common.cu)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB})
...
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/softmax.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
using compute_t = float;
template <typename Type>
void compute_single_head_fwd(
Type *softmax_out,
const Type *data_in,
compute_t *buff,
const float scaling_factor,
const int rows,
const int cols)
{
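// Row i of the attention matrix sees its first (i + cols - rows + 1) elements:
// the mask is aligned to the bottom-right corner, so the last row always sees
// all `cols` elements, and leading rows may be fully masked when rows > cols.
// Subtracting the row max below is the standard numerically stable softmax trick.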
for (int i = 0; i < rows; ++i) {
size_t offset = i * cols;
const int masked_elements = i + cols - rows + 1;
compute_t max_value = static_cast<compute_t>(-10'000.f);
for (int j = 0; j < masked_elements; ++j) {
compute_t tmp = scaling_factor * static_cast<compute_t>(data_in[offset + j]);
buff[offset + j] = tmp;
max_value = std::max(max_value, tmp);
}
compute_t accumulator = static_cast<compute_t>(0.f);
for (int j = 0; j < masked_elements; ++j) {
compute_t tmp = std::exp(buff[offset + j] - max_value);
buff[offset + j] = tmp;
accumulator += tmp;
}
for (int j = 0; j < cols; ++j) {
if (j < masked_elements) {
compute_t tmp = buff[offset + j] / accumulator;
softmax_out[offset + j] = static_cast<Type>(tmp);
} else {
softmax_out[offset + j] = static_cast<Type>(0.f);
}
}
}
}
template <typename Type>
void compute_single_head_bwd(
Type *grad_out,
const Type *grad_in,
const Type *softmax_in,
compute_t *buff,
const float scaling_factor,
const int rows,
const int cols)
{
for (int i = 0; i < rows; ++i) {
size_t offset = i * cols;
const int masked_elements = i + cols - rows + 1;
compute_t accumulator = static_cast<compute_t>(0.f);
for (int j = 0; j < masked_elements; ++j) {
compute_t tmp = static_cast<compute_t>(softmax_in[offset + j])
* static_cast<compute_t>(grad_in[offset + j]);
buff[offset + j] = tmp;
accumulator += tmp;
}
for (int j = 0; j < cols; ++j) {
if (j < masked_elements) {
compute_t tmp = buff[offset + j]
- static_cast<compute_t>(softmax_in[offset + j]) * accumulator;
grad_out[offset + j] = static_cast<Type>(scaling_factor * tmp);
} else {
grad_out[offset + j] = static_cast<Type>(0.f);
}
}
}
}
template <typename Type>
void compute_fwd_ref(
Type *softmax_out,
const Type *data_in,
compute_t *buff,
const float scaling_factor,
const int batches,
const int heads,
const int rows,
const int cols)
{
size_t head_size = rows * cols;
size_t batch_size = heads * head_size;
for (int b = 0; b < batches; ++b) {
for (int h = 0; h < heads; ++h) {
size_t offset = b * batch_size + h * head_size;
compute_single_head_fwd(softmax_out + offset, data_in + offset,
buff + offset, scaling_factor, rows, cols);
}
}
}
template <typename Type>
void compute_bwd_ref(
Type *grad_out,
const Type *grad_in,
const Type *softmax_in,
compute_t *buff,
const float scaling_factor,
const int batches,
const int heads,
const int rows,
const int cols)
{
size_t head_size = rows * cols;
size_t batch_size = heads * head_size;
for (int b = 0; b < batches; ++b) {
for (int h = 0; h < heads; ++h) {
size_t offset = b * batch_size + h * head_size;
compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
buff + offset, scaling_factor, rows, cols);
}
}
}
// Query Sequence Length = rows
// Key Sequence Length = cols
template <typename Type>
void performTest(
const size_t batches,
const size_t heads,
const size_t rows,
const size_t cols,
float scaling_factor)
{
using namespace test;
DType itype = TypeInfo<Type>::dtype;
Tensor data_in({ batches, heads, rows, cols }, itype);
Tensor softmax_out({ batches, heads, rows, cols }, itype);
Tensor softmax_in({ batches, heads, rows, cols }, itype);
Tensor grads_in({ batches, heads, rows, cols }, itype);
Tensor grads_out({ batches, heads, rows, cols }, itype);
const size_t elements_total = batches * heads * rows * cols;
std::unique_ptr<Type[]> softmax_out_ref = std::make_unique<Type[]>(elements_total);
std::unique_ptr<Type[]> grads_out_ref = std::make_unique<Type[]>(elements_total);
std::unique_ptr<compute_t[]> compute_buffer = std::make_unique<compute_t[]>(elements_total);
fillUniform(&data_in);
fillUniform(&softmax_in);
fillUniform(&grads_in);
nvte_scaled_aligned_causal_masked_softmax_forward(
data_in.data(), softmax_out.data(), scaling_factor, 0);
nvte_scaled_aligned_causal_masked_softmax_backward(
grads_in.data(), softmax_in.data(), grads_out.data(), scaling_factor, 0);
// Reference implementations
compute_fwd_ref(softmax_out_ref.get(), data_in.cpu_dptr<Type>(),
compute_buffer.get(), scaling_factor, batches, heads, rows, cols);
compute_bwd_ref(grads_out_ref.get(), grads_in.cpu_dptr<Type>(), softmax_in.cpu_dptr<Type>(),
compute_buffer.get(), scaling_factor, batches, heads, rows, cols);
cudaDeviceSynchronize();
auto err = cudaGetLastError();
ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
auto [atol, rtol] = getTolerances(itype);
if (itype == DType::kBFloat16) {
atol = 1e-3;
}
compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), atol, rtol);
compareResults("softmax_bwd", grads_out, grads_out_ref.get(), atol, rtol);
}
// [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor]
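// These shapes exercise all three regimes the kernel documents:
// key_seq_len > query_seq_len (most cases), key_seq_len == query_seq_len
// (512 x 512), and key_seq_len < query_seq_len (270 x 256).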
std::vector<std::tuple<size_t, size_t, size_t, size_t, float>> test_cases = {
{ 1, 1, 1, 16, -1.0f},
{ 1, 2, 17, 32, 0.8f},
{ 2, 1, 37, 112, 1.0f},
{ 2, 4, 127, 128, -0.2f},
{ 8, 6, 128, 256, 1.3f},
{ 1, 4, 270, 256, 0.8f},
{ 2, 2, 512, 512, -1.5f},
{ 1, 2, 819, 1024, 2.1f},
{ 1, 2, 281, 1024, 0.2f},
{ 1, 2, 277, 1024, -2.1f},
{ 1, 2, 127, 1024, 1.1f},
{ 2, 2, 107, 2048, 0.4f},
{ 2, 1, 103, 2048, -3.0f},
{ 2, 2, 101, 2048, 2.6f},
{ 1, 1, 1024, 4096, 0.6f},
{ 1, 2, 61, 4096, 0.6f},
{ 1, 2, 59, 4096, -4.9f},
{ 1, 2, 53, 4096, 3.5f},
{ 1, 1, 37, 8192, 0.7f},
{ 1, 1, 31, 8192, -5.8f},
{ 1, 1, 29, 8192, 4.4f},
{ 1, 1, 23, 12288, 0.8f},
{ 1, 1, 19, 12288, -6.7f},
{ 1, 1, 17, 12288, 3.3f},
{ 1, 1, 13, 16384, 0.9f},
{ 1, 1, 11, 16384, -7.6f},
{ 1, 1, 7, 16384, 6.2f}};
} // namespace
class CausalSoftmaxTestSuite
: public ::testing::TestWithParam<std::tuple<
transformer_engine::DType,
std::tuple<size_t, size_t, size_t, size_t, float>>> {};
TEST_P(CausalSoftmaxTestSuite, TestCausalSoftmax) {
using namespace transformer_engine;
using namespace test;
const DType input_type = std::get<0>(GetParam());
const auto size = std::get<1>(GetParam());
const size_t batches = std::get<0>(size);
const size_t heads = std::get<1>(size);
const size_t query_seq_len = std::get<2>(size);
const size_t key_seq_len = std::get<3>(size);
const float scaling_factor = std::get<4>(size);
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
performTest<InputType>(batches, heads, query_seq_len, key_seq_len, scaling_factor);
);
}
INSTANTIATE_TEST_SUITE_P(
OperatorTest,
CausalSoftmaxTestSuite,
::testing::Combine(
::testing::Values(DType::kFloat16, DType::kBFloat16),
::testing::ValuesIn(test_cases)),
[](const testing::TestParamInfo<CausalSoftmaxTestSuite::ParamType>& info) {
const auto size = std::get<1>(info.param);
const size_t batches = std::get<0>(size);
const size_t heads = std::get<1>(size);
const size_t query_seq_len = std::get<2>(size);
const size_t key_seq_len = std::get<3>(size);
std::string scaling_factor = std::to_string(std::get<4>(size));
for (char& c : scaling_factor) {
if (c == '-') { c = 'N'; }
if (c == '.') { c = 'p'; }
}
std::string name = test::typeName(std::get<0>(info.param)) + "X" +
std::to_string(batches) + "X" +
std::to_string(heads) + "X" +
std::to_string(query_seq_len) + "X" +
std::to_string(key_seq_len) + "X" +
scaling_factor;
return name;
});
@@ -33,8 +33,7 @@ list(APPEND transformer_engine_SOURCES
util/system.cpp
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
- fused_softmax/scaled_masked_softmax.cu
- fused_softmax/scaled_upper_triang_masked_softmax.cu
+ fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_rope/fused_rope.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
@@ -87,6 +86,7 @@ target_include_directories(transformer_engine PRIVATE
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
+ fused_softmax/scaled_aligned_causal_masked_softmax.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
...
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <stdint.h>
#include <cfloat>
#include <limits>
#include <array>
#include <functional>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <transformer_engine/softmax.h>
#include "../common.h"
#include "../utils.cuh"
#include "../util/logging.h"
namespace transformer_engine {
template<typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template<>
__device__ __inline__ void copy_vector<bf16, 1>(bf16 *dst, const bf16 *src) {
*dst = *src;
}
template<>
__device__ __inline__ void copy_vector<bf16, 4>(bf16 *dst, const bf16 *src) {
*((uint64_t*) dst) = *((uint64_t*) src); // NOLINT(*)
}
template<>
__device__ __inline__ void copy_vector<fp16, 1>(fp16 *dst, const fp16 *src) {
*dst = *src;
}
template<>
__device__ __inline__ void copy_vector<fp16, 4>(fp16 *dst, const fp16 *src) {
*((uint64_t*) dst) = *((uint64_t*) src); // NOLINT(*)
}
template<>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) {
*dst = *src;
}
template<>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {
*((uint32_t*) dst) = *((uint32_t*) src); // NOLINT(*)
}
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<bf16, 1>(bf16 *dst) {
*dst = 0.0f;
}
template <>
__device__ __inline__ void copy_zero_vector<bf16, 4>(bf16 *dst) {
*((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_zero_vector<fp16, 1>(fp16 *dst) {
*dst = 0.0f;
}
template <>
__device__ __inline__ void copy_zero_vector<fp16, 4>(fp16 *dst) {
*((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*)
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
return a + b;
}
};
template<typename T>
struct Max {
__device__ __forceinline__ T operator()(T a, T b) const {
return a < b ? b : a;
}
};
template<typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize,
unsigned int mask = 0xffffffff) {
#if CUDA_VERSION >= 9000
return __shfl_xor_sync(mask, value, laneMask, width);
#else
return __shfl_xor(value, laneMask, width);
#endif
}
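// Butterfly reduction: XOR shuffles with strides WARP_SIZE/2, ..., 2, 1 leave
// every participating lane holding the full reduction (max or sum) of its row.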
template<typename acc_t, int WARP_ROWS, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
ReduceOp<acc_t> r;
#pragma unroll
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0; i < WARP_ROWS; ++i) {
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
sum[i] = r(sum[i], b);
}
}
}
/*
* Extended softmax (from native aten pytorch) with the following additional features
* 1) input scaling
* 2) implicit causal masking
*
* works for all cases:
* k > q
* k < q
* k = q
*
* where:
* microbatches = batches * attn_heads * query_seq_len
* rows = query_seq_len
* cols = key_seq_len
*/
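/*
 * Example (rows = 3, cols = 5): row i keeps its first (i + cols - rows + 1)
 * elements, so the unmasked region is aligned to the bottom-right corner:
 *
 *   row 0:  x x x . .
 *   row 1:  x x x x .
 *   row 2:  x x x x x
 */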
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_aligned_causal_masked_softmax_warp_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
const int microbatches,
const int rows,
const int cols
) {
// 1) WARP_WIDTH must match the value of warp_size
// 2) WARP_ROWS must match the value of rows_per_warp
// of the dispatch_scaled_aligned_causal_masked_softmax_forward method.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_WIDTH = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH;
constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS;
const int col = threadIdx.x * ELEMENTS_PER_LDG_STG;
const size_t thread_offset = global_row_idx * cols + col;
src += thread_offset;
dst += thread_offset;
// load data from global memory into registers WITH scaling
acc_t elements[WARP_ROWS][WARP_ITERATIONS];
input_t temp_data[ELEMENTS_PER_LDG_STG];
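// Masked positions receive the large negative constant -10'000 instead of
// -infinity: for partially masked rows their exp() contribution vanishes,
// and a fully masked row (possible when rows > cols) keeps a finite max, so
// the max/exp computation never produces NaN (its output is zeroed on store).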
#pragma unroll
for (int w = 0; w < WARP_ROWS; ++w) {
const int microbatch = global_row_idx + w;
const int i = microbatch % rows; // local row index of attention matrix
const int masked_elements = i + cols - rows + 1;
if (microbatch >= microbatches) {
break;
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
const int j = col + it * WARP_WIDTH;
const int itr_idx = w * cols + it * WARP_WIDTH;
if (j < masked_elements) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (j + element < masked_elements) {
elements[w][it + element] = (acc_t)temp_data[element] * scale;
} else {
elements[w][it + element] = (acc_t)( -10'000 );
}
}
} else {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
elements[w][it + element] = (acc_t)( -10'000 );
}
}
}
}
// compute max_value
acc_t max_value[WARP_ROWS];
#pragma unroll
for (int w = 0; w < WARP_ROWS; ++w) {
max_value[w] = elements[w][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
max_value[w] =
(max_value[w] > elements[w][it]) ? max_value[w] : elements[w][it];
}
}
warp_reduce<acc_t, WARP_ROWS, WARP_WIDTH, Max>(max_value);
acc_t sum[WARP_ROWS] { 0.0f };
#pragma unroll
for (int w = 0; w < WARP_ROWS; ++w) {
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; ++it) {
elements[w][it] = expf((elements[w][it] - max_value[w]));
sum[w] += elements[w][it];
}
}
warp_reduce<acc_t, WARP_ROWS, WARP_WIDTH, Add>(sum);
output_t out[ELEMENTS_PER_LDG_STG] { 0.0f };
// store result
#pragma unroll
for (int w = 0; w < WARP_ROWS; ++w) {
const int microbatch = global_row_idx + w;
const int i = microbatch % rows;
const int masked_elements = i + cols - rows + 1;
// out of Attention matrix bounds (rows)
if (microbatch >= microbatches) {
break;
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
const int j = col + it * WARP_WIDTH; // index of the first column
const int itr_idx = w * cols + it * WARP_WIDTH;
if (j < masked_elements) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (j + element < masked_elements) {
out[element] = elements[w][it + element] / sum[w];
} else {
out[element] = (output_t)( 0.0f );
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + itr_idx, out);
} else if (j < cols) {
copy_zero_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + itr_idx);
} else {
break;
}
}
}
}
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_aligned_causal_masked_softmax_warp_backward(
output_t *gradInput,
const input_t *grad,
const input_t *softmax_output,
const acc_t scale,
const int microbatches,
const int rows,
const int cols
) {
// 1) WARP_WIDTH must match the value of warp_size
// 2) WARP_ROWS must match the value of rows_per_warp
// of the dispatch_scaled_aligned_causal_masked_softmax_forward method.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_WIDTH = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_WIDTH;
constexpr int WARP_ROWS = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
const int global_row_idx = (blockIdx.x * blockDim.y + threadIdx.y) * WARP_ROWS;
const int col = threadIdx.x * ELEMENTS_PER_LDG_STG;
const size_t thread_offset = global_row_idx * cols + col;
grad += thread_offset;
softmax_output += thread_offset;
gradInput += thread_offset;
// load data from global memory into registers
acc_t grad_reg[WARP_ROWS][WARP_ITERATIONS] { 0.0f };
acc_t softmax_output_reg[WARP_ROWS][WARP_ITERATIONS] { 0.0f };
input_t temp_grad[ELEMENTS_PER_LDG_STG];
input_t temp_output[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int w = 0; w < WARP_ROWS; ++w) {
const int microbatch = global_row_idx + w;
const int i = microbatch % rows; // local row index of attention matrix
const int masked_elements = i + cols - rows + 1;
if (microbatch >= microbatches) {
break;
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
const int j = col + it * WARP_WIDTH; // index of the first column
const int itr_idx = w * cols + it * WARP_WIDTH;
if (j < masked_elements) {
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + itr_idx);
copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, softmax_output + itr_idx);
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
if (j + element < masked_elements) {
softmax_output_reg[w][it + element] = (acc_t)temp_output[element];
grad_reg[w][it + element] =
(acc_t)temp_grad[element] * softmax_output_reg[w][it + element];
}
}
}
}
}
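// Softmax backward: with y = softmax(scale * x) and upstream gradient dy,
// dL/dx_j = scale * y_j * (dy_j - sum_k(y_k * dy_k)).
// grad_reg already holds y_j * dy_j, so reducing it yields the inner sum.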
acc_t sum[WARP_ROWS];
#pragma unroll
for (int w = 0; w < WARP_ROWS; ++w) {
sum[w] = grad_reg[w][0];
#pragma unroll
for (int it = 1; it < WARP_ITERATIONS; ++it) {
sum[w] += grad_reg[w][it];
}
}
warp_reduce<acc_t, WARP_ROWS, WARP_WIDTH, Add>(sum);
// store result
#pragma unroll
for (int w = 0; w < WARP_ROWS; ++w) {
const int microbatch = global_row_idx + w;
if (microbatch >= microbatches) {
break;
}
#pragma unroll
for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
const int j = col + it * WARP_WIDTH; // index of the first column
const int itr_idx = w * cols + it * WARP_WIDTH;
if (j < cols) {
output_t out[ELEMENTS_PER_LDG_STG];
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = (output_t)(scale * (grad_reg[w][it + element] -
softmax_output_reg[w][it + element] * sum[w]));
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + itr_idx, out);
}
}
}
}
template<typename input_t, typename output_t, typename acc_t, int log2_elements>
void call_kernel_scaled_aligned_causal_masked_softmax_forward(
dim3 grid_size,
dim3 block_size,
const int shmem_size,
cudaStream_t stream,
output_t *dst,
const input_t *src,
const acc_t scale,
const int microbatches,
const int query_seq_len,
const int key_seq_len
) {
scaled_aligned_causal_masked_softmax_warp_forward<input_t, output_t, acc_t, log2_elements>
<<<grid_size, block_size, shmem_size, stream>>>(
dst, src, scale, microbatches, query_seq_len, key_seq_len);
}
template<typename input_t, typename output_t, typename acc_t, int log2_elements>
void call_kernel_scaled_aligned_causal_masked_softmax_backward(
dim3 grid_size,
dim3 block_size,
const int shmem_size,
cudaStream_t stream,
output_t *gradInput,
const input_t *grad,
const input_t *output,
const acc_t scale,
const int microbatches,
const int query_seq_len,
const int key_seq_len
) {
scaled_aligned_causal_masked_softmax_warp_backward<input_t, output_t, acc_t, log2_elements>
<<<grid_size, block_size, 0, stream>>>(
gradInput, grad, output, scale, microbatches, query_seq_len, key_seq_len);
}
template<typename input_t, typename output_t, typename acc_t>
struct FunctionWrapper {
using ForwardType = std::function<
void(
dim3 grid_size,
dim3 block_size,
const int shmem_size,
cudaStream_t stream,
output_t *dst,
const input_t *src,
const acc_t scale,
const int microbatches,
const int query_seq_len,
const int key_seq_len
)
>;
using BackwardType = std::function<
void(
dim3 grid_size,
dim3 block_size,
const int shmem_size,
cudaStream_t stream,
output_t *gradInput,
const input_t *grad,
const input_t *output,
const acc_t scale,
const int microbatches,
const int query_seq_len,
const int key_seq_len
)
>;
};
constexpr int MIN_SUPPORTED_POWER = 4;
constexpr int MAX_SUPPORTED_POWER = 14;
constexpr int MIN_POWER = MIN_SUPPORTED_POWER - 1;
constexpr int MAX_POWER = MAX_SUPPORTED_POWER + 1;
// Recursively instantiate the function for the limit of "log2_elements",
// i.e. "MAX_POWER" defined above.
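// The supported range corresponds to key sequence lengths from
// 2^MIN_SUPPORTED_POWER = 16 up to 2^MAX_SUPPORTED_POWER = 16384; slots
// outside this range are left unpopulated.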
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
struct CompileTimeLoopForward {
using ForwardFuncType = typename FunctionWrapper<input_t, output_t, acc_t>::ForwardType;
static void populate(std::array<ForwardFuncType, MAX_POWER>* arr) {
CompileTimeLoopForward<input_t, output_t, acc_t, log2_elements - 1>::populate(arr);
(*arr)[log2_elements] = &call_kernel_scaled_aligned_causal_masked_softmax_forward<
input_t, output_t, acc_t, log2_elements>;
}
};
template <typename input_t, typename output_t, typename acc_t>
struct CompileTimeLoopForward<input_t, output_t, acc_t, MIN_POWER> {
using ForwardFuncType = typename FunctionWrapper<input_t, output_t, acc_t>::ForwardType;
static void populate(std::array<ForwardFuncType, MAX_POWER>* arr) {
(*arr)[MIN_POWER] = nullptr;
}
};
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
struct CompileTimeLoopBackward {
using BackwardFuncType = typename FunctionWrapper<input_t, output_t, acc_t>::BackwardType;
static void populate(std::array<BackwardFuncType, MAX_POWER>* arr) {
CompileTimeLoopBackward<input_t, output_t, acc_t, log2_elements - 1>::populate(arr);
(*arr)[log2_elements] = &call_kernel_scaled_aligned_causal_masked_softmax_backward<
input_t, output_t, acc_t, log2_elements>;
}
};
template <typename input_t, typename output_t, typename acc_t>
struct CompileTimeLoopBackward<input_t, output_t, acc_t, MIN_POWER> {
using BackwardFuncType = typename FunctionWrapper<input_t, output_t, acc_t>::BackwardType;
static void populate(std::array<BackwardFuncType, MAX_POWER>* arr) {
(*arr)[MIN_POWER] = nullptr;
}
};
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_aligned_causal_masked_softmax_forward(
output_t *dst,
const input_t *src,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
cudaStream_t stream
) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) {
return;
}
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_WIDTH constexpr
// value computed inside scaled_aligned_causal_masked_softmax_warp_forward.
int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
// This value must match the WARP_ROWS constexpr
// value computed inside scaled_aligned_causal_masked_softmax_warp_forward.
int microbatches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = threads_per_block / warp_width;
int microbatches_per_block = warps_per_block * microbatches_per_warp;
int microbatches = batches * attn_heads * query_seq_len;
int blocks = DIVUP(microbatches, microbatches_per_block);
dim3 block_size(warp_width, warps_per_block);
dim3 grid_size(blocks);
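// Example: key_seq_len = 1024 gives log2_elements = 10 and next_power_of_two =
// 1024, so warp_width = 32, microbatches_per_warp = 1, warps_per_block =
// 128 / 32 = 4, and each block covers 4 rows of the attention matrix.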
// Build a static table of kernel launchers indexed by log2_elements
using ForwardFuncType = typename FunctionWrapper<input_t, output_t, acc_t>::ForwardType;
static std::array<ForwardFuncType, MAX_POWER> forwardFunctionArray;
static bool is_initialized = false;
if (!is_initialized) {
CompileTimeLoopForward<input_t, output_t, acc_t, MAX_SUPPORTED_POWER>::populate(
&forwardFunctionArray);
is_initialized = true;
}
// Call the corresponding kernel
forwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, dst, src, scale,
microbatches, query_seq_len, key_seq_len);
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_aligned_causal_masked_softmax_backward(
output_t *grad_input,
const input_t *grad,
const input_t *output,
const acc_t scale,
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads,
cudaStream_t stream
) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) {
return;
}
int log2_elements = log2_ceil(key_seq_len);
const int next_power_of_two = 1 << log2_elements;
// This value must match the WARP_WIDTH constexpr
// value computed inside scaled_aligned_causal_masked_softmax_warp_forward.
int warp_width = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two : THREADS_PER_WARP;
// This value must match the WARP_ROWS constexpr
// value computed inside scaled_aligned_causal_masked_softmax_warp_forward.
int microbatches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = threads_per_block / warp_width;
int microbatches_per_block = warps_per_block * microbatches_per_warp;
int microbatches = batches * attn_heads * query_seq_len;
int blocks = DIVUP(microbatches, microbatches_per_block);
dim3 block_size(warp_width, warps_per_block);
dim3 grid_size(blocks);
// Build a static table of kernel launchers indexed by log2_elements
using BackwardFuncType = typename FunctionWrapper<input_t, output_t, acc_t>::BackwardType;
static std::array<BackwardFuncType, MAX_POWER> backwardFunctionArray;
static bool is_initialized = false;
if (!is_initialized) {
CompileTimeLoopBackward<input_t, output_t, acc_t, MAX_SUPPORTED_POWER>::populate(
&backwardFunctionArray);
is_initialized = true;
}
// Call the corresponding kernel
backwardFunctionArray[log2_elements](grid_size, block_size, 0, stream, grad_input, grad,
output, scale, microbatches, query_seq_len, key_seq_len);
}
void scaled_aligned_causal_masked_softmax_forward(
const Tensor &input,
Tensor *softmax_results,
float scale_factor,
cudaStream_t stream) {
const int batches = input.data.shape[0];
const int attn_heads = input.data.shape[1];
const int query_seq_len = input.data.shape[2];
const int key_seq_len = input.data.shape[3];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.data.dtype, softmax_type,
dispatch_scaled_aligned_causal_masked_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->data.dptr),
reinterpret_cast<const softmax_type*>(input.data.dptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
stream););
}
void scaled_aligned_causal_masked_softmax_backward(
Tensor output_grads,
const Tensor incoming_grads,
const Tensor softmax_results,
float scale_factor,
cudaStream_t stream) {
// output grads is a 4D tensor with dimensions [batches, attn_heads, query_seq_len, key_seq_len]
const int batches = output_grads.data.shape[0];
const int attn_heads = output_grads.data.shape[1];
const int query_seq_len = output_grads.data.shape[2];
const int key_seq_len = output_grads.data.shape[3];
// Softmax Grad
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.data.dtype, softmax_type,
dispatch_scaled_aligned_causal_masked_softmax_backward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(output_grads.data.dptr),
reinterpret_cast<softmax_type const*>(incoming_grads.data.dptr),
reinterpret_cast<softmax_type const*>(softmax_results.data.dptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
stream););
}
} // end namespace transformer_engine
void nvte_scaled_aligned_causal_masked_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
) {
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_forward);
using namespace transformer_engine;
scaled_aligned_causal_masked_softmax_forward(
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(softmax_results),
scale_factor,
stream);
}
void nvte_scaled_aligned_causal_masked_softmax_backward(
const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads,
float scale_factor,
cudaStream_t stream
) {
NVTE_API_CALL(nvte_scaled_aligned_causal_masked_softmax_backward);
using namespace transformer_engine;
scaled_aligned_causal_masked_softmax_backward(
*reinterpret_cast<Tensor*>(output_grads),
*reinterpret_cast<const Tensor*>(incoming_grads),
*reinterpret_cast<const Tensor*>(softmax_results),
scale_factor,
stream);
}
@@ -466,7 +466,7 @@ void dispatch_scaled_softmax_forward(
int batches,
int attn_heads,
cudaStream_t stream) {
- NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape.");
+ NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) {
return;
} else {
@@ -597,6 +597,22 @@ void dispatch_scaled_softmax_forward(
batch_count,
key_seq_len);
break;
case 13: // 8192
scaled_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 14: // 16384
scaled_softmax_warp_forward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
default:
break;
}
@@ -615,7 +631,7 @@ void dispatch_scaled_masked_softmax_forward(
int attn_heads,
int pad_batches,
cudaStream_t stream) {
- NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape.");
+ NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) {
return;
} else {
@@ -772,6 +788,26 @@ void dispatch_scaled_masked_softmax_forward(
key_seq_len,
pad_batches);
break;
case 13: // 8192
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 14: // 16384
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
default:
break;
}
@@ -789,7 +825,7 @@ void dispatch_scaled_masked_softmax_backward(
int batches,
int attn_heads,
cudaStream_t stream) {
- NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape.");
+ NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) {
return;
} else {
@@ -833,7 +869,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -843,7 +878,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -853,7 +887,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -863,7 +896,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -873,7 +905,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -883,7 +914,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -893,7 +923,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -903,7 +932,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -913,7 +941,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -923,7 +950,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -933,7 +959,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
- break;
case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, stream>>>(grad_input,
@@ -943,8 +968,24 @@ void dispatch_scaled_masked_softmax_backward(
batch_count,
key_seq_len);
break;
case 13: // 8192
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 14: // 16384
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
- break;
+ break;
default:
break;
}
...
@@ -368,7 +368,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int softmax_elements_stride,
int attn_batches,
cudaStream_t stream) {
- NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape.");
+ NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape.");
if (softmax_elements == 0) {
return;
} else {
@@ -506,6 +506,33 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
softmax_elements_stride,
softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 14: // 16384
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
default:
break;
}
@@ -522,7 +549,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int softmax_elements_stride,
int attn_batches,
cudaStream_t stream) {
- NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape.");
+ NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape.");
if (softmax_elements == 0) {
return;
} else {
@@ -660,6 +687,33 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
softmax_elements_stride,
softmax_elements);
break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 14: // 16384
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
default:
break;
}
...
@@ -125,6 +125,41 @@ void nvte_scaled_upper_triang_masked_softmax_backward(
);
/*! \brief Compute scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
*
* \param[in] input Input tensor for softmax.
* \param[out] softmax_results Output tensor.
* \param[in] scale_factor Scalar for the input tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_aligned_causal_masked_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
/*! \brief Compute the backward pass of the scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
*
* - `incoming_grads` is the input tensor containing the gradients received from the following layer.
* - `softmax_results` is the output tensor of the corresponding forward softmax operation.
* - `output_grads` is the output tensor containing the computed gradients.
*
* \param[in] incoming_grads Input gradient tensor for backward.
* \param[in] softmax_results Output tensor of softmax forward.
* \param[out] output_grads Output tensor.
* \param[in] scale_factor Scalar for the output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_aligned_causal_masked_softmax_backward(
const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads,
float scale_factor,
cudaStream_t stream
);
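/*
 * Illustrative call sequence (a sketch: `scores`, `probs`, `dprobs`, `dscores`
 * and `stream` are assumed to be set up by the caller; they are not part of
 * this header):
 *
 *   // scores, probs: [batches, heads, query_seq_len, key_seq_len], fp16/bf16
 *   nvte_scaled_aligned_causal_masked_softmax_forward(scores, probs,
 *                                                     scale_factor, stream);
 *
 *   // dprobs: gradient w.r.t. probs; dscores: gradient w.r.t. scores
 *   nvte_scaled_aligned_causal_masked_softmax_backward(dprobs, probs, dscores,
 *                                                      scale_factor, stream);
 */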
#ifdef __cplusplus
} // extern "C"
#endif
...
@@ -521,6 +521,17 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
float scale_factor
);
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input,
float scale_factor
);
at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor
);
/***************************************************************************************************
* Rotary positional embedding
**************************************************************************************************/
...
@@ -23,6 +23,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scaled_upper_triang_masked_softmax_backward",
&scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD");
m.def("scaled_aligned_causal_masked_softmax_forward",
&scaled_aligned_causal_masked_softmax_forward,
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD");
m.def("scaled_aligned_causal_masked_softmax_backward",
&scaled_aligned_causal_masked_softmax_backward,
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD");
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
...
@@ -20,7 +20,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input,
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
- AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less");
+ AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
@@ -92,7 +92,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input,
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
- AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less");
+ AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
@@ -160,7 +160,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
- AT_ASSERTM(seq_len <= 2048, "Sequence length must be 2048 or less");
+ AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less");
// Output
auto act_options = input.options().requires_grad(false);
@@ -212,3 +212,75 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
return output_grads;
}
at::Tensor scaled_aligned_causal_masked_softmax_forward(
at::Tensor input,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater than or equal to 1");
// Output
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_aligned_causal_masked_softmax_forward(
input_cu.data(),
softmax_results_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_aligned_causal_masked_softmax_backward(
at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
// Produce gradients in place.
nvte_scaled_aligned_causal_masked_softmax_backward(
output_grads_cu.data(),
softmax_results_cu.data(),
output_grads_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return output_grads;
}
@@ -113,6 +113,70 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
return out
class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply causal mask aligned to the bottom right corner of the input matrix
3. Perform softmax.
"""
@staticmethod
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledAlignedCausalMaskedSoftmax fwd"""
scale_t = torch.tensor([scale])
softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(
inputs, scale_t[0]
)
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledAlignedCausalMaskedSoftmax bwd"""
softmax_results, scale_t = ctx.saved_tensors
input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None
@staticmethod
@fp32_compute
def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
"""ScaledAlignedCausalMaskedSoftmax symbolic method"""
def triangular_mask():
dtype = _type_utils.JitScalarType.INT64
ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
# rectangular causal mask aligned to the bottom right corner of Attention matrix
rows = inputs.size(dim=-2)
cols = inputs.size(dim=-1)
diag_shift = cols - rows + 1
mask = g.op("Trilu", ones, k, upper_i=diag_shift)
mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
return mask
# Captures the logic of scaled_aligned_causal_masked_softmax_warp_forward
mask = triangular_mask()
one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
inv_mask = g.op("Sub", one, mask)
neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
softmax_mask = g.op("Mul", mask, neg_tenK)
scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
scaled = g.op("Mul", inputs, scale_input)
masked_scaled = g.op("Mul", inv_mask, scaled)
masked = g.op("Add", masked_scaled, softmax_mask)
out = g.op("Softmax", masked)
return out
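# Illustrative usage (a sketch; shapes and names are assumed, not part of this
# diff): given `scores` of shape [batches, heads, query_seq_len, key_seq_len]
# in fp16/bf16 on CUDA, `probs = ScaledAlignedCausalMaskedSoftmax.apply(scores,
# scale)` scales the inputs, applies the aligned causal mask, and takes softmax.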
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
@@ -272,13 +336,13 @@ class FusedScaleMaskSoftmax(nn.Module):
if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16
- and 16 < sk <= 4096 # sk must be 16 ~ 2048
+ and 16 <= sk <= 16384 # sk must be 16 ~ 16384
and sk % 8 == 0 # sk must be divisor of 8
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
and self.attn_mask_type != "arbitrary" # Custom masks not supported
):
- if 0 <= sk <= 4096:
+ if 0 <= sk <= 16384:
batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "causal":
...