Unverified Commit d9eb1991 authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[common] Added new unfused softmax cuda kernel to support causal attention mask (#652)



* Added new unfused softmax cuda kernel to support causal attention mask
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Added test suite for unfused causal softmax kernel
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Removed test cases with large matrices from the causal softmax test suite
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Cleaned up the code per lint
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Added a compute buffer to causal softmax testing suite to store intermediate results without casting
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Added more test cases
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Relaxed absolute tolerance atol
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* Relaxed absolute tolerance for BF16
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
parent 94de051f
...@@ -15,6 +15,7 @@ add_executable(test_operator ...@@ -15,6 +15,7 @@ add_executable(test_operator
test_layernorm.cu test_layernorm.cu
test_rmsnorm.cu test_rmsnorm.cu
test_multi_cast_transpose.cu test_multi_cast_transpose.cu
test_causal_softmax.cu
../test_common.cu) ../test_common.cu)
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB}) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB})
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>
#include <transformer_engine/softmax.h>
#include "../test_common.h"
using namespace transformer_engine;
namespace {
// Accumulation type used by all CPU reference implementations.
using compute_t = float;

// CPU reference for the forward pass of a single attention head:
// scales the input, applies a causal mask aligned to the bottom-right
// corner of the (rows x cols) matrix, and computes a numerically stable
// row-wise softmax. Masked positions are written as zero.
template <typename Type>
void compute_single_head_fwd(
    Type *softmax_out,
    const Type *data_in,
    compute_t *buff,
    const float scaling_factor,
    const int rows,
    const int cols)
{
    for (int row = 0; row < rows; ++row) {
        const size_t base = static_cast<size_t>(row) * cols;
        // Number of unmasked (visible) elements in this row; the mask is
        // aligned so the last row sees every column.
        const int visible = row + cols - rows + 1;

        // Pass 1: scale the inputs and track the row maximum for stability.
        compute_t row_max = static_cast<compute_t>(-10'000.f);
        for (int col = 0; col < visible; ++col) {
            const compute_t scaled =
                scaling_factor * static_cast<compute_t>(data_in[base + col]);
            buff[base + col] = scaled;
            if (scaled > row_max) {
                row_max = scaled;
            }
        }

        // Pass 2: exponentiate the shifted values and accumulate the denominator.
        compute_t denom = static_cast<compute_t>(0.f);
        for (int col = 0; col < visible; ++col) {
            const compute_t e = std::exp(buff[base + col] - row_max);
            buff[base + col] = e;
            denom += e;
        }

        // Pass 3: normalize visible elements; zero out the masked tail.
        for (int col = 0; col < cols; ++col) {
            softmax_out[base + col] =
                (col < visible)
                    ? static_cast<Type>(buff[base + col] / denom)
                    : static_cast<Type>(0.f);
        }
    }
}
// CPU reference for the backward pass of a single attention head.
// Computes grad = scale * (y * dy - y * sum_k(y_k * dy_k)) row-wise,
// honoring the same bottom-right-aligned causal mask as the forward pass.
// Masked positions receive a zero gradient.
// NOTE(review): `batches` and `heads` are unused here — presumably kept for
// signature symmetry with the dispatch helpers; confirm before removing.
template <typename Type>
void compute_single_head_bwd(
    Type *grad_out,          // [rows x cols] gradient w.r.t. the softmax input
    const Type *grad_in,     // [rows x cols] incoming gradient
    const Type *softmax_in,  // [rows x cols] forward softmax results
    compute_t *buff,         // [rows x cols] scratch buffer in compute precision
    const float scaling_factor,
    const int batches,
    const int heads,
    const int rows,
    const int cols)
{
    for (int i = 0; i < rows; ++i) {
        size_t offset = i * cols;
        // Visible (unmasked) element count for this row.
        const int masked_elements = i + cols - rows + 1;
        // Accumulate sum_k(y_k * dy_k) over the visible elements.
        compute_t accumulator = static_cast<compute_t>(0.f);
        for (int j = 0; j < masked_elements; ++j) {
            compute_t tmp = static_cast<compute_t>(softmax_in[offset + j])
                * static_cast<compute_t>(grad_in[offset + j]);
            buff[offset + j] = tmp;
            accumulator += tmp;
        }
        for (int j = 0; j < cols; ++j) {
            if (j < masked_elements) {
                // dL/dx_j = scale * (y_j * dy_j - y_j * sum_k(y_k * dy_k))
                compute_t tmp = buff[offset + j]
                    - static_cast<compute_t>(softmax_in[offset + j]) * accumulator;
                grad_out[offset + j] = static_cast<Type>(scaling_factor * tmp);
            } else {
                // Masked positions contribute no gradient.
                grad_out[offset + j] = static_cast<Type>(0.f);
            }
        }
    }
}
// CPU reference forward pass over a full [batches, heads, rows, cols] tensor:
// applies the single-head reference to every (batch, head) slice in order.
template <typename Type>
void compute_fwd_ref(
    Type *softmax_out,
    const Type *data_in,
    compute_t *buff,
    const float scaling_factor,
    const int batches,
    const int heads,
    const int rows,
    const int cols)
{
    const size_t head_elems = static_cast<size_t>(rows) * cols;
    const size_t num_slices = static_cast<size_t>(batches) * heads;
    // Slices are laid out contiguously, so a flat loop visits the same
    // offsets as the nested batch/head loops would.
    for (size_t slice = 0; slice < num_slices; ++slice) {
        const size_t offset = slice * head_elems;
        compute_single_head_fwd(softmax_out + offset, data_in + offset,
                                buff + offset, scaling_factor, rows, cols);
    }
}
// CPU reference backward pass over a full [batches, heads, rows, cols] tensor:
// applies the single-head backward reference to every (batch, head) slice.
template <typename Type>
void compute_bwd_ref(
    Type *grad_out,
    const Type *grad_in,
    const Type *softmax_in,
    compute_t *buff,
    const float scaling_factor,
    const int batches,
    const int heads,
    const int rows,
    const int cols)
{
    size_t head_size = rows * cols;
    size_t batch_size = heads * head_size;
    for (int b = 0; b < batches; ++b) {
        for (int h = 0; h < heads; ++h) {
            size_t offset = b * batch_size + h * head_size;
            // NOTE: batches/heads are forwarded but unused by the per-head helper.
            compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
                                    buff + offset, scaling_factor, batches, heads, rows, cols);
        }
    }
}
// Query Sequence Length = rows
// Key Sequence Length = cols
// End-to-end check for one configuration: runs the fused CUDA
// forward/backward causal-softmax kernels and compares their outputs
// against the CPU reference implementations above.
// Query sequence length = rows, key sequence length = cols.
template <typename Type>
void performTest(
    const size_t batches,
    const size_t heads,
    const size_t rows,
    const size_t cols,
    float scaling_factor)
{
    using namespace test;

    DType itype = TypeInfo<Type>::dtype;

    // Device tensors; inputs are filled with random data below.
    Tensor data_in({ batches, heads, rows, cols }, itype);
    Tensor softmax_out({ batches, heads, rows, cols }, itype);
    Tensor softmax_in({ batches, heads, rows, cols }, itype);
    Tensor grads_in({ batches, heads, rows, cols }, itype);
    Tensor grads_out({ batches, heads, rows, cols }, itype);

    const size_t elements_total = batches * heads * rows * cols;
    // Host-side reference outputs plus a float buffer so intermediate
    // results are kept in compute precision without casting.
    std::unique_ptr<Type[]> softmax_out_ref = std::make_unique<Type[]>(elements_total);
    std::unique_ptr<Type[]> grads_out_ref = std::make_unique<Type[]>(elements_total);
    std::unique_ptr<compute_t[]> compute_buffer = std::make_unique<compute_t[]>(elements_total);

    fillUniform(&data_in);
    fillUniform(&softmax_in);
    fillUniform(&grads_in);

    // Launch the fused kernels (stream 0, i.e. the default CUDA stream).
    nvte_scaled_aligned_causal_masked_softmax_forward(
        data_in.data(), softmax_out.data(), scaling_factor, 0);
    nvte_scaled_aligned_causal_masked_softmax_backward(
        grads_in.data(), softmax_in.data(), grads_out.data(), scaling_factor, 0);

    // Reference implementations (run on the host while the GPU works).
    compute_fwd_ref(softmax_out_ref.get(), data_in.cpu_dptr<Type>(),
                    compute_buffer.get(), scaling_factor, batches, heads, rows, cols);
    compute_bwd_ref(grads_out_ref.get(), grads_in.cpu_dptr<Type>(), softmax_in.cpu_dptr<Type>(),
                    compute_buffer.get(), scaling_factor, batches, heads, rows, cols);

    // Wait for the kernels and surface any asynchronous launch errors.
    cudaDeviceSynchronize();
    auto err = cudaGetLastError();
    ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

    auto [atol, rtol] = getTolerances(itype);
    if(itype == DType::kBFloat16) {
        // BF16 has fewer mantissa bits, so the absolute tolerance is relaxed.
        atol = 1e-3;
    }
    compareResults("softmax_fwd", softmax_out, softmax_out_ref.get(), atol, rtol);
    compareResults("softmax_bwd", grads_out, grads_out_ref.get(), atol, rtol);
}
// Test configurations:
// [Batches, Attention Heads, Query Sequence Length, Key Sequence Length, Scaling Factor]
// Key sequence lengths span 16 through 16384; scaling factors include
// negative and non-unit values to exercise the scale path.
std::vector<std::tuple<size_t, size_t, size_t, size_t, float>> test_cases = {
    { 1, 1, 1, 16, -1.0f},
    { 1, 2, 17, 32, 0.8f},
    { 2, 1, 37, 112, 1.0f},
    { 2, 4, 127, 128, -0.2f},
    { 8, 6, 128, 256, 1.3f},
    { 1, 4, 270, 256, 0.8f},
    { 2, 2, 512, 512, -1.5f},
    { 1, 2, 819, 1024, 2.1f},
    { 1, 2, 281, 1024, 0.2f},
    { 1, 2, 277, 1024, -2.1f},
    { 1, 2, 127, 1024, 1.1f},
    { 2, 2, 107, 2048, 0.4f},
    { 2, 1, 103, 2048, -3.0f},
    { 2, 2, 101, 2048, 2.6f},
    { 1, 1, 1024, 4096, 0.6f},
    { 1, 2, 61, 4096, 0.6f},
    { 1, 2, 59, 4096, -4.9f},
    { 1, 2, 53, 4096, 3.5f},
    { 1, 1, 37, 8192, 0.7f},
    { 1, 1, 31, 8192, -5.8f},
    { 1, 1, 29, 8192, 4.4f},
    { 1, 1, 23, 12288, 0.8f},
    { 1, 1, 19, 12288, -6.7f},
    { 1, 1, 17, 12288, 3.3f},
    { 1, 1, 13, 16384, 0.9f},
    { 1, 1, 11, 16384, -7.6f},
    { 1, 1, 7, 16384, 6.2f}};
} // namespace
// Test fixture parameterized over (input dtype, shape/scale configuration).
class CausalSoftmaxTestSuite
    : public ::testing::TestWithParam<std::tuple<
        transformer_engine::DType,
        std::tuple<size_t, size_t, size_t, size_t, float>>> {};
// Parameterized entry point: unpacks the configuration tuple and
// dispatches performTest for the requested input type.
TEST_P(CausalSoftmaxTestSuite, TestCausalSoftmax) {
    using namespace transformer_engine;
    using namespace test;

    const DType input_type = std::get<0>(GetParam());
    const auto& config = std::get<1>(GetParam());

    const size_t batches = std::get<0>(config);
    const size_t heads = std::get<1>(config);
    const size_t query_seq_len = std::get<2>(config);
    const size_t key_seq_len = std::get<3>(config);
    const float scaling_factor = std::get<4>(config);

    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
        performTest<InputType>(batches, heads, query_seq_len, key_seq_len, scaling_factor);
    );
}
// Registers the full test grid: {fp16, bf16} x all shape/scale cases.
// The name generator encodes dtype and parameters into the test name,
// mapping '-' to 'N' and '.' to 'p' so names stay identifier-safe.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    CausalSoftmaxTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat16, DType::kBFloat16),
        ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<CausalSoftmaxTestSuite::ParamType>& info) {
        const auto& dims = std::get<1>(info.param);

        std::string scale_str = std::to_string(std::get<4>(dims));
        for (char& ch : scale_str) {
            if (ch == '-') {
                ch = 'N';
            } else if (ch == '.') {
                ch = 'p';
            }
        }

        return test::typeName(std::get<0>(info.param)) + "X" +
               std::to_string(std::get<0>(dims)) + "X" +
               std::to_string(std::get<1>(dims)) + "X" +
               std::to_string(std::get<2>(dims)) + "X" +
               std::to_string(std::get<3>(dims)) + "X" +
               scale_str;
    });
...@@ -33,8 +33,7 @@ list(APPEND transformer_engine_SOURCES ...@@ -33,8 +33,7 @@ list(APPEND transformer_engine_SOURCES
util/system.cpp util/system.cpp
fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_rope/fused_rope.cu) fused_rope/fused_rope.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC target_include_directories(transformer_engine PUBLIC
...@@ -87,6 +86,7 @@ target_include_directories(transformer_engine PRIVATE ...@@ -87,6 +86,7 @@ target_include_directories(transformer_engine PRIVATE
# Compiler options # Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_aligned_causal_masked_softmax.cu
PROPERTIES PROPERTIES
COMPILE_OPTIONS "--use_fast_math") COMPILE_OPTIONS "--use_fast_math")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
......
...@@ -466,7 +466,7 @@ void dispatch_scaled_softmax_forward( ...@@ -466,7 +466,7 @@ void dispatch_scaled_softmax_forward(
int batches, int batches,
int attn_heads, int attn_heads,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape."); NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) { if (key_seq_len == 0) {
return; return;
} else { } else {
...@@ -597,6 +597,22 @@ void dispatch_scaled_softmax_forward( ...@@ -597,6 +597,22 @@ void dispatch_scaled_softmax_forward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
case 13: // 8192
scaled_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
case 14: // 16384
scaled_softmax_warp_forward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
key_seq_len);
break;
default: default:
break; break;
} }
...@@ -615,7 +631,7 @@ void dispatch_scaled_masked_softmax_forward( ...@@ -615,7 +631,7 @@ void dispatch_scaled_masked_softmax_forward(
int attn_heads, int attn_heads,
int pad_batches, int pad_batches,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape."); NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) { if (key_seq_len == 0) {
return; return;
} else { } else {
...@@ -772,6 +788,26 @@ void dispatch_scaled_masked_softmax_forward( ...@@ -772,6 +788,26 @@ void dispatch_scaled_masked_softmax_forward(
key_seq_len, key_seq_len,
pad_batches); pad_batches);
break; break;
case 13: // 8192
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
case 14: // 16384
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
batch_count,
key_seq_len,
pad_batches);
break;
default: default:
break; break;
} }
...@@ -789,7 +825,7 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -789,7 +825,7 @@ void dispatch_scaled_masked_softmax_backward(
int batches, int batches,
int attn_heads, int attn_heads,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape."); NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 16384, "Unsupported shape.");
if (key_seq_len == 0) { if (key_seq_len == 0) {
return; return;
} else { } else {
...@@ -833,7 +869,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -833,7 +869,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 2: // 4 case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -843,7 +878,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -843,7 +878,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 3: // 8 case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -853,7 +887,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -853,7 +887,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 4: // 16 case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -863,7 +896,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -863,7 +896,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 5: // 32 case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -873,7 +905,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -873,7 +905,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 6: // 64 case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -883,7 +914,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -883,7 +914,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 7: // 128 case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -893,7 +923,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -893,7 +923,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 8: // 256 case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -903,7 +932,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -903,7 +932,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 9: // 512 case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -913,7 +941,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -913,7 +941,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 10: // 1024 case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -923,7 +950,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -923,7 +950,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 11: // 2048 case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -933,7 +959,6 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -933,7 +959,6 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
break;
case 12: // 4096 case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12> scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, stream>>>(grad_input, <<<blocks, threads, 0, stream>>>(grad_input,
...@@ -943,8 +968,24 @@ void dispatch_scaled_masked_softmax_backward( ...@@ -943,8 +968,24 @@ void dispatch_scaled_masked_softmax_backward(
batch_count, batch_count,
key_seq_len); key_seq_len);
break; break;
case 13: // 8192
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break;
case 14: // 16384
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
batch_count,
key_seq_len);
break; break;
default: default:
break; break;
} }
......
...@@ -368,7 +368,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( ...@@ -368,7 +368,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int softmax_elements_stride, int softmax_elements_stride,
int attn_batches, int attn_batches,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape."); NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape.");
if (softmax_elements == 0) { if (softmax_elements == 0) {
return; return;
} else { } else {
...@@ -506,6 +506,33 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( ...@@ -506,6 +506,33 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
softmax_elements_stride, softmax_elements_stride,
softmax_elements); softmax_elements);
break; break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 14: // 16384
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
default: default:
break; break;
} }
...@@ -522,7 +549,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( ...@@ -522,7 +549,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int softmax_elements_stride, int softmax_elements_stride,
int attn_batches, int attn_batches,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape."); NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 16384, "Unsupported shape.");
if (softmax_elements == 0) { if (softmax_elements == 0) {
return; return;
} else { } else {
...@@ -660,6 +687,33 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( ...@@ -660,6 +687,33 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
softmax_elements_stride, softmax_elements_stride,
softmax_elements); softmax_elements);
break; break;
case 12: // 4096
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 13: // 8192
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 13>
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
case 14: // 16384
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 14>
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
softmax_elements_stride,
softmax_elements);
break;
default: default:
break; break;
} }
......
...@@ -125,6 +125,41 @@ void nvte_scaled_upper_triang_masked_softmax_backward( ...@@ -125,6 +125,41 @@ void nvte_scaled_upper_triang_masked_softmax_backward(
); );
/*! \brief Compute scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
*
* \param[in] input Input tensor for softmax.
* \param[out] softmax_results Output tensor.
* \param[in] scale_factor Scalar for the input tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_aligned_causal_masked_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
/*! \brief Compute the backward pass of the scaled softmax activation using an implicit 2D mask aligned to the bottom right corner of the input matrix.
*
* - `incoming_grads` is the input tensor containing the gradients received from the following layer.
* - `softmax_results` is the output tensor of the corresponding forward softmax operation.
* - `output_grads` is the output tensor containing the computed gradients.
*
* \param[in] incoming_grads Input gradient tensor for backward.
* \param[in] softmax_results Output tensor of softmax forward.
* \param[out] output_grads Output tensor.
* \param[in] scale_factor Scalar for the output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_scaled_aligned_causal_masked_softmax_backward(
const NVTETensor incoming_grads,
const NVTETensor softmax_results,
NVTETensor output_grads,
float scale_factor,
cudaStream_t stream
);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -521,6 +521,17 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, ...@@ -521,6 +521,17 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
float scale_factor float scale_factor
); );
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input,
float scale_factor
);
at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor
);
/*************************************************************************************************** /***************************************************************************************************
* Rotary positional embedding * Rotary positional embedding
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -23,6 +23,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -23,6 +23,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scaled_upper_triang_masked_softmax_backward", m.def("scaled_upper_triang_masked_softmax_backward",
&scaled_upper_triang_masked_softmax_backward, &scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD"); "Scaled Upper-Triangular Masked Softmax BWD");
m.def("scaled_aligned_causal_masked_softmax_forward",
&scaled_aligned_causal_masked_softmax_forward,
"Scaled Bottom-Right Corner Aligned Masked Softmax FWD");
m.def("scaled_aligned_causal_masked_softmax_backward",
&scaled_aligned_causal_masked_softmax_backward,
"Scaled Bottom-Right Corner Aligned Masked Softmax BWD");
// Other granular functions // Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8"); m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
......
...@@ -20,7 +20,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input, ...@@ -20,7 +20,7 @@ at::Tensor scaled_softmax_forward(at::Tensor input,
const int query_seq_len = input.size(2); const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3); const int key_seq_len = input.size(3);
AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less"); AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
...@@ -92,7 +92,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input, ...@@ -92,7 +92,7 @@ at::Tensor scaled_masked_softmax_forward(at::Tensor input,
const int query_seq_len = input.size(2); const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3); const int key_seq_len = input.size(3);
AT_ASSERTM(key_seq_len <= 4096, "Key sequence length must be 4096 or less"); AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8"); AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1"); AT_ASSERTM(query_seq_len > 1, "Query sequence length must be greater than 1");
TORCH_CHECK(pad_batches == 1 || pad_batches == batches); TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
...@@ -160,7 +160,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, ...@@ -160,7 +160,7 @@ at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
const int attn_batches = input.size(0); const int attn_batches = input.size(0);
const int seq_len = input.size(1); const int seq_len = input.size(1);
AT_ASSERTM(seq_len <= 2048, "Sequence length must be 2048 or less"); AT_ASSERTM(seq_len <= 16384, "Sequence length must be 16384 or less");
// Output // Output
auto act_options = input.options().requires_grad(false); auto act_options = input.options().requires_grad(false);
...@@ -212,3 +212,75 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_, ...@@ -212,3 +212,75 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
return output_grads; return output_grads;
} }
// PyTorch wrapper: forward pass of the scaled, bottom-right-aligned
// causal-masked softmax. Expects a 4D fp16/bf16 tensor shaped
// [batches, attn_heads, query_seq_len, key_seq_len] and returns a new
// tensor of the same shape holding the softmax results.
// NOTE(review): unlike the backward wrapper, `input` is not made
// contiguous here — presumably callers always pass contiguous tensors;
// confirm against the call sites.
at::Tensor scaled_aligned_causal_masked_softmax_forward(
    at::Tensor input,
    float scale_factor
) {
    using namespace transformer_engine;

    AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
               (input.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    const int batches = input.size(0);
    const int attn_heads = input.size(1);
    const int query_seq_len = input.size(2);
    const int key_seq_len = input.size(3);

    // Kernel dispatch constraints on the key sequence length.
    AT_ASSERTM(key_seq_len <= 16384, "Key sequence length must be 16384 or less");
    AT_ASSERTM(key_seq_len % 8 == 0, "Key sequence length must be divisible by 8");
    AT_ASSERTM(query_seq_len >= 1, "Query sequence length must be greater or equal to 1");

    // Output
    auto act_options = input.options().requires_grad(false);
    auto softmax_results =
        torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

    auto input_cu = makeTransformerEngineTensor(input);
    auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

    nvte_scaled_aligned_causal_masked_softmax_forward(
        input_cu.data(),
        softmax_results_cu.data(),
        scale_factor,
        at::cuda::getCurrentCUDAStream());

    return softmax_results;
}
// PyTorch wrapper: backward pass of the scaled, bottom-right-aligned
// causal-masked softmax. Takes the incoming gradients and the forward
// softmax results (both 4D fp16/bf16) and returns the gradients w.r.t.
// the softmax input. The gradients are computed in place in the
// (contiguous copy of the) incoming-gradient tensor.
at::Tensor scaled_aligned_causal_masked_softmax_backward(
    at::Tensor output_grad_,
    at::Tensor softmax_results_,
    float scale_factor
) {
    using namespace transformer_engine;

    // Ensure contiguous layouts before handing pointers to the kernel.
    auto output_grads = output_grad_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");

    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
               (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
               (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    auto output_grads_cu = makeTransformerEngineTensor(output_grads);
    auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

    // Produce gradients in place: output buffer aliases the input gradients.
    nvte_scaled_aligned_causal_masked_softmax_backward(
        output_grads_cu.data(),
        softmax_results_cu.data(),
        output_grads_cu.data(),
        scale_factor,
        at::cuda::getCurrentCUDAStream());

    return output_grads;
}
...@@ -113,6 +113,70 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): ...@@ -113,6 +113,70 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
return out return out
class ScaledAlignedCausalMaskedSoftmax(torch.autograd.Function):
    """
    Fused operation which performs following three operations in sequence
    1. Scale the tensor.
    2. Apply causal mask aligned to the bottom right corner of the input matrix
    3. Perform softmax.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
        """ScaledAlignedCausalMaskedSoftmax fwd"""
        scale_t = torch.tensor([scale])
        softmax_results = tex.scaled_aligned_causal_masked_softmax_forward(
            inputs, scale_t[0]
        )

        # Save the forward output and the scale for use in backward.
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results

    @staticmethod
    def backward(
        ctx, output_grads: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """ScaledAlignedCausalMaskedSoftmax bwd"""
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = tex.scaled_aligned_causal_masked_softmax_backward(
            output_grads, softmax_results, scale_t[0]
        )

        # No gradient flows to the `scale` argument.
        return input_grads, None

    @staticmethod
    @fp32_compute
    def symbolic(g: torch.Graph, inputs: torch._C.Value, scale: float) -> torch._C.Value:
        """ScaledAlignedCausalMaskedSoftmax symbolic method"""

        def triangular_mask():
            # Builds a boolean mask of the *masked-out* positions: ones strictly
            # above the diagonal, shifted so the mask hugs the bottom-right
            # corner of the attention matrix.
            dtype = _type_utils.JitScalarType.INT64
            ones = torch.onnx.symbolic_opset9.ones_like(g, inputs, dtype)
            k = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
            # rectangular causal mask aligned to the bottom right corner of Attention matrix
            rows = inputs.size(dim=-2)
            cols = inputs.size(dim=-1)
            diag_shift = cols - rows + 1

            mask = g.op("Trilu", ones, k, upper_i=diag_shift)
            mask = g.op("Cast", mask, to_i=_C_onnx.TensorProtoDataType.BOOL)
            return mask

        # Captures the logic of function scaled_aligned_masked_softmax_warp_forward
        mask = triangular_mask()
        one = g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64))
        inv_mask = g.op("Sub", one, mask)

        # Masked positions receive -10000 before the softmax, mirroring the
        # additive-mask behavior of the CUDA kernel.
        neg_tenK = g.op("Constant", value_t=torch.tensor(-10000., dtype=torch.float16))
        softmax_mask = g.op("Mul", mask, neg_tenK)

        scale_input = g.op("Constant", value_t=torch.tensor(scale, dtype=torch.float16))
        scaled = g.op("Mul", inputs, scale_input)
        masked_scaled = g.op("Mul", inv_mask, scaled)
        masked = g.op("Add", masked_scaled, softmax_mask)
        # NOTE(review): "Softmax" is emitted without an explicit axis —
        # presumably the opset default is intended here; confirm against the
        # exported graph for 4D inputs.
        out = g.op("Softmax", masked)
        return out
class ScaledMaskedSoftmax(torch.autograd.Function): class ScaledMaskedSoftmax(torch.autograd.Function):
""" """
Fused operation which performs following three operations in sequence Fused operation which performs following three operations in sequence
...@@ -272,13 +336,13 @@ class FusedScaleMaskSoftmax(nn.Module): ...@@ -272,13 +336,13 @@ class FusedScaleMaskSoftmax(nn.Module):
if ( # pylint: disable=too-many-boolean-expressions if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user wants to fuse self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16 and self.input_in_float16 # input must be fp16
and 16 < sk <= 4096 # sk must be 16 ~ 2048 and 16 <= sk <= 16384 # sk must be 16 ~ 16384
and sk % 8 == 0 # sk must be divisor of 8 and sk % 8 == 0 # sk must be divisor of 8
and sq % 4 == 0 # sq must be divisor of 4 and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4 and attn_batches % 4 == 0 # np * b must be divisor of 4
and self.attn_mask_type != "arbitrary" # Custom masks not supported and self.attn_mask_type != "arbitrary" # Custom masks not supported
): ):
if 0 <= sk <= 4096: if 0 <= sk <= 16384:
batch_per_block = self.get_batch_per_block(int(sk)) batch_per_block = self.get_batch_per_block(int(sk))
if self.attn_mask_type == "causal": if self.attn_mask_type == "causal":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment