Unverified commit f674d49e authored by Kirthi Shankar Sivamani, committed by GitHub

Framework agnostic softmax kernels (#30)



* Make fused softmax kernels PyTorch independent
Co-authored-by: Sean Lee <selee@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move get_batch_per_block to Python
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix license in softmax.h
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sean Lee <selee@nvidia.com>
parent b2743878
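A note on the user-visible effect: the three standalone CUDA extension modules removed from setup.py below are replaced by bindings on the main transformer_engine_extensions module. A minimal sketch of the new call path in Python, assuming a CUDA build with the extensions installed (shapes and dtypes follow the checks added in this commit):

    import torch
    import transformer_engine_extensions as tex

    # 4D input [batches, attn_heads, query_seq_len, key_seq_len];
    # fp16/bf16 only, key_seq_len <= 4096, query_seq_len > 1.
    scores = torch.randn(2, 16, 128, 128, dtype=torch.float16, device="cuda")
    scale = 1.0 / 8.0

    probs = tex.scaled_softmax_forward(scores, scale)
    # the backward binding writes its result in place over the incoming grads
    dscores = tex.scaled_softmax_backward(torch.ones_like(probs), probs, scale)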
@@ -137,75 +137,6 @@ if framework in ("all", "pytorch"):
)
)
ext_modules.append(
CUDAExtension(
name="scaled_upper_triang_masked_softmax_cuda",
sources=[
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_upper_triang_masked_softmax.cpp",
),
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu",
),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
},
include_dirs=[
os.path.join(path, "transformer_engine/pytorch/csrc/fused_softmax")
],
)
)
ext_modules.append(
CUDAExtension(
name="scaled_masked_softmax_cuda",
sources=[
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_masked_softmax.cpp",
),
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_masked_softmax_cuda.cu",
),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
},
include_dirs=[
os.path.join(path, "transformer_engine/pytorch/csrc/fused_softmax")
],
)
)
ext_modules.append(
CUDAExtension(
name="scaled_softmax_cuda",
sources=[
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_softmax.cpp",
),
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_softmax_cuda.cu",
),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
},
include_dirs=[
os.path.join(path, "transformer_engine/pytorch/csrc/fused_softmax")
],
)
)
def get_cmake_bin():
cmake_bin = "cmake"
......
@@ -31,7 +31,9 @@ add_library(transformer_engine SHARED
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
util/cast.cu)
util/cast.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
@@ -41,3 +43,10 @@ list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
@@ -219,6 +219,26 @@ struct TypeInfo{
NVTE_ERROR("Invalid type."); \
}
#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
switch (dtype) \
{ \
using namespace transformer_engine; \
case DType::kFloat16: \
{ \
using type = fp16; \
__VA_ARGS__; \
break; \
} \
case DType::kBFloat16: \
{ \
using type = bf16; \
__VA_ARGS__; \
break; \
} \
default: \
NVTE_ERROR("Invalid type for 16 bit."); \
}
template<typename T>
struct TypeId{};
@@ -283,6 +303,12 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
inline int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template <typename T>
struct is_fp8 : std::false_type {};
......
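For reference, log2_ceil above returns the smallest n with 2**n >= value; for value >= 1 this coincides with Python's (value - 1).bit_length(), which is the identity the Python port of get_batch_per_block later in this commit relies on. A quick standalone check:

    def log2_ceil(value: int) -> int:
        """Python mirror of the C++ helper above."""
        log2_value = 0
        while (1 << log2_value) < value:
            log2_value += 1
        return log2_value

    for v in (1, 2, 5, 128, 2048, 4096):
        assert log2_ceil(v) == (v - 1).bit_length()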
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SOFTMAX_H_
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
void nvte_scaled_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_softmax_backward(
const NVTETensor output_grads,
const NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_masked_softmax_forward(
const NVTETensor input,
const NVTETensor mask,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_masked_softmax_backward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_upper_triang_masked_softmax_forward(
const NVTETensor input,
NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
void nvte_scaled_upper_triang_masked_softmax_backward(
const NVTETensor output_grads,
const NVTETensor softmax_results,
float scale_factor,
cudaStream_t stream
);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_SOFTMAX_H_
@@ -14,6 +14,7 @@
#include <transformer_engine/logging.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/softmax.h>
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cuda/CUDAContext.h>
......
@@ -483,8 +483,219 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
}
at::Tensor scaled_softmax_forward(at::Tensor input,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096);
TORCH_CHECK(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor,
at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_softmax_backward(
output_grads_cu.data(), softmax_results_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return output_grads;
}
at::Tensor scaled_masked_softmax_forward(at::Tensor input,
at::Tensor mask,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_CHECK(key_seq_len <= 4096);
TORCH_CHECK(query_seq_len > 1);
TORCH_CHECK(pad_batches == 1 || pad_batches == batches);
TORCH_CHECK(mask.size(1) == 1);
TORCH_CHECK(mask.size(2) == query_seq_len);
TORCH_CHECK(mask.size(3) == key_seq_len);
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto mask_cu = makeTransformerEngineTensor(mask);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_masked_softmax_forward(
input_cu.data(), mask_cu.data(), softmax_results_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grad_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_softmax_backward(
output_grads_cu.data(), softmax_results_cu.data(),
scale_factor, at::cuda::getCurrentCUDAStream());
return output_grads;
}
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
float scale_factor
) {
using namespace transformer_engine;
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_CHECK(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
auto softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(),
softmax_results_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return softmax_results;
}
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor
) {
using namespace transformer_engine;
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
TORCH_CHECK(output_grads.size(1) == output_grads.size(2));
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);
nvte_scaled_upper_triang_masked_softmax_backward(output_grads_cu.data(),
softmax_results_cu.data(),
scale_factor,
at::cuda::getCurrentCUDAStream());
return output_grads;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Granular functions
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD");
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD");
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD");
m.def("scaled_upper_triang_masked_softmax_forward",
&scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD");
m.def("scaled_upper_triang_masked_softmax_backward",
&scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD");
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
......
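The masked variant additionally takes a mask broadcast over attention heads; per the checks above its shape must be [batches or 1, 1, query_seq_len, key_seq_len]. A sketch of a conforming call (the torch.uint8 mask dtype is an assumption based on the kernel reading the mask through a uint8_t pointer; nonzero entries are masked out):

    import torch
    import transformer_engine_extensions as tex

    b, heads, sq, sk = 2, 16, 128, 128
    scores = torch.randn(b, heads, sq, sk, dtype=torch.bfloat16, device="cuda")
    mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8, device="cuda")
    mask[..., sk // 2:] = 1  # hide the second half of the key positions

    probs = tex.scaled_masked_softmax_forward(scores, mask, 0.125)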
@@ -116,3 +116,37 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
transformer_engine::DType itype,
transformer_engine::DType otype
);
at::Tensor scaled_softmax_forward(at::Tensor input,
float scale_factor
);
at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
);
at::Tensor scaled_masked_softmax_forward(at::Tensor input,
at::Tensor mask,
float scale_factor
);
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
at::Tensor softmax_results_,
float scale_factor
);
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
float scale_factor
);
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
at::Tensor softmax_results_,
float scale_factor
);
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
int get_batch_per_block_cuda(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads);
torch::Tensor fwd(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");
return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}
} // end namespace scaled_masked_softmax
} // end namespace transformer_engine
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&transformer_engine::scaled_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&transformer_engine::scaled_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
m.def("get_batch_per_block",
&transformer_engine::scaled_masked_softmax::get_batch_per_block,
"Return Batch per block size.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor) {
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads););
// backward pass is completely in-place
return output_grads;
}
} // end namespace scaled_masked_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(
torch::Tensor const& input,
float scale_factor) {
AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_softmax
} // end namespace transformer_engine
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&transformer_engine::scaled_softmax::fwd,
"Self Multihead Attention scaled, softmax -- Forward.");
m.def("backward",
&transformer_engine::scaled_softmax::bwd,
"Self Multihead Attention scaled, softmax -- Backward.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor) {
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_softmax_forward",
dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads););
// backward pass is completely in-place
return output_grads;
}
} // end namespace scaled_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
(input.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return fwd_cuda(input, scale_factor);
}
torch::Tensor bwd(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor) {
AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
(output_grads.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
(softmax_results.scalar_type() == at::ScalarType::BFloat16),
"Only fp16 and bf16 are supported");
return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace transformer_engine
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&transformer_engine::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&transformer_engine::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor) {
// input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = input.size(0);
const int seq_len = input.size(1);
TORCH_INTERNAL_ASSERT(seq_len <= 2048);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({attn_batches, seq_len, seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_forward",
dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
seq_len,
seq_len,
attn_batches););
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
// output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
const int attn_batches = output_grads.size(0);
const int seq_len = output_grads.size(1);
TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
// Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_upper_triang_masked_softmax_backward",
dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
seq_len,
seq_len,
attn_batches););
// backward pass is completely in-place
return output_grads;
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...) \
switch (TYPE) \
{ \
case at::ScalarType::Half: \
{ \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_in = float; \
switch (TYPEOUT) \
{ \
case at::ScalarType::Float: \
{ \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: \
{ \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
@@ -9,6 +9,11 @@ from typing import Callable, Tuple, Union
import torch
from torch import nn
import transformer_engine_extensions as tex
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
@@ -21,10 +26,8 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledUpperTriangMaskedSoftmax fwd"""
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
softmax_results = tex.scaled_upper_triang_masked_softmax_forward(
inputs, scale_t[0]
)
@@ -36,10 +39,8 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledUpperTriangMaskedSoftmax bwd"""
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
input_grads = tex.scaled_upper_triang_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
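These wrappers are driven through the usual torch.autograd.Function interface. A sketch for the upper-triangular (causal) variant, which per the C++ checks expects a 3D [attn_batches, seq_len, seq_len] input with seq_len <= 2048 (the import path here is illustrative):

    import torch
    from transformer_engine.pytorch.softmax import ScaledUpperTriangMaskedSoftmax

    scores = torch.randn(32, 256, 256, dtype=torch.float16,
                         device="cuda", requires_grad=True)
    probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 0.125)
    probs.sum().backward()  # routes through tex.scaled_upper_triang_masked_softmax_backward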
@@ -59,11 +60,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float
) -> torch.Tensor:
"""ScaledMaskedSoftmax fwd"""
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@@ -72,11 +71,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledMaskedSoftmax bwd"""
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(
input_grads = tex.scaled_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
@@ -92,11 +89,9 @@ class ScaledSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledSoftmax fwd"""
import scaled_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@@ -105,11 +100,9 @@ class ScaledSoftmax(torch.autograd.Function):
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledSoftmax bwd"""
import scaled_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_softmax_cuda.backward(
input_grads = tex.scaled_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
@@ -170,7 +163,7 @@ class FusedScaleMaskSoftmax(nn.Module):
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
batch_per_block = self.get_batch_per_block(sk)
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
@@ -220,8 +213,11 @@
return probs
@staticmethod
def get_batch_per_block(sq: int, sk: int, b: int, np: int) -> int:
def get_batch_per_block(key_seq_len: int) -> int:
"""Softmax utility"""
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
pow2 = 1 << (key_seq_len - 1).bit_length()
warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
batches_per_warp = 2 if pow2 <= 128 else 1
warps_per_block = THREADS_PER_BLOCK // warp_size
batches_per_block = warps_per_block * batches_per_warp
return batches_per_block
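The heuristic mirrors the launch math previously computed on the CUDA side: round key_seq_len up to a power of two, give each softmax row a warp of at most 32 threads, and pack two batches per warp for rows of 128 elements or fewer. A few spot checks with THREADS_PER_WARP = 32 and THREADS_PER_BLOCK = 128 (a sketch, calling the static method directly; the import path is illustrative):

    from transformer_engine.pytorch.softmax import FusedScaleMaskSoftmax

    assert FusedScaleMaskSoftmax.get_batch_per_block(16) == 16    # pow2=16:   8 warps x 2
    assert FusedScaleMaskSoftmax.get_batch_per_block(128) == 8    # pow2=128:  4 warps x 2
    assert FusedScaleMaskSoftmax.get_batch_per_block(2048) == 4   # pow2=2048: 4 warps x 1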