Unverified commit f674d49e, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Framework agnostic softmax kernels (#30)



* Make fused softmax kernels PyTorch independent
Co-authored-by: Sean Lee <selee@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Address review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* move get_batch_per_block to python
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix license in softmax.h
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Sean Lee <selee@nvidia.com>
parent b2743878
......@@ -137,75 +137,6 @@ if framework in ("all", "pytorch"):
)
)
ext_modules.append(
CUDAExtension(
name="scaled_upper_triang_masked_softmax_cuda",
sources=[
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_upper_triang_masked_softmax.cpp",
),
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_upper_triang_masked_softmax_cuda.cu",
),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
},
include_dirs=[
os.path.join(path, "transformer_engine/pytorch/csrc/fused_softmax")
],
)
)
ext_modules.append(
CUDAExtension(
name="scaled_masked_softmax_cuda",
sources=[
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_masked_softmax.cpp",
),
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_masked_softmax_cuda.cu",
),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
},
include_dirs=[
os.path.join(path, "transformer_engine/pytorch/csrc/fused_softmax")
],
)
)
ext_modules.append(
CUDAExtension(
name="scaled_softmax_cuda",
sources=[
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_softmax.cpp",
),
os.path.join(
path,
"transformer_engine/pytorch/csrc/fused_softmax/scaled_softmax_cuda.cu",
),
],
extra_compile_args={
"cxx": ["-O3"],
"nvcc": append_nvcc_threads(extra_compiler_flags() + cc_flag),
},
include_dirs=[
os.path.join(path, "transformer_engine/pytorch/csrc/fused_softmax")
],
)
)
def get_cmake_bin():
cmake_bin = "cmake"
......
......@@ -31,7 +31,9 @@ add_library(transformer_engine SHARED
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
util/cast.cu)
util/cast.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
target_include_directories(transformer_engine PUBLIC "${PROJECT_SOURCE_DIR}/include")
......@@ -41,3 +43,10 @@ list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
......@@ -219,6 +219,26 @@ struct TypeInfo{
NVTE_ERROR("Invalid type."); \
}
// Runtime-dtype dispatch for the two supported 16-bit floating-point types.
// Binds the alias `type` to transformer_engine::fp16 or transformer_engine::bf16
// according to `dtype`, then executes __VA_ARGS__ with that alias in scope.
// Any other DType aborts via NVTE_ERROR.
// NOTE: comments must stay outside the macro body — a `//` comment before a
// trailing line-continuation backslash would splice the next line into the
// comment during preprocessing.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
switch (dtype) \
{ \
using namespace transformer_engine; \
case DType::kFloat16: \
{ \
using type = fp16; \
__VA_ARGS__; \
break; \
} \
case DType::kBFloat16: \
{ \
using type = bf16; \
__VA_ARGS__; \
break; \
} \
default: \
NVTE_ERROR("Invalid type for 16 bit."); \
}
// Empty primary template used as a type-to-identifier mapping hook.
// NOTE(review): presumably specialized per concrete type elsewhere in this
// file — the specializations are not visible in this chunk; confirm there.
template<typename T>
struct TypeId{};
......@@ -283,6 +303,12 @@ inline size_t product(const std::vector<size_t> &shape) {
return ret;
}
// Returns ceil(log2(value)): the smallest k such that (1 << k) >= value.
// Values <= 1 (including 0 and negatives) yield 0.
inline int log2_ceil(int value) {
    int k = 0;
    for (int capacity = 1; capacity < value; capacity <<= 1) {
        ++k;
    }
    return k;
}
// Type trait marking FP8 element types. The primary template defaults to
// false for every T; the fp8 specializations are outside this view.
template <typename T>
struct is_fp8 : std::false_type {};
......
......@@ -4,42 +4,48 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_SCALED_MASKED_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SCALED_MASKED_SOFTMAX_H_
#include <transformer_engine/softmax.h>
#include <transformer_engine/logging.h>
#include <assert.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <c10/macros/Macros.h>
#include <cfloat>
#include <limits>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "../utils.cuh"
#include "../common.h"
namespace transformer_engine {
namespace {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
__device__ __inline__ void copy_vector<bf16, 1>(bf16 *dst,
const bf16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
__device__ __inline__ void copy_vector<bf16, 4>(bf16 *dst,
const bf16 *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
__device__ __inline__ void copy_vector<half, 1>(half *dst,
const half *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
__device__ __inline__ void copy_vector<half, 4>(half *dst,
const half *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
......@@ -55,12 +61,6 @@ __device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst,
*((half2*) dst) = *((half2*) src); // NOLINT(*)
}
// Smallest exponent e with (1 << e) >= value; 0 for value <= 1.
// (Same helper as the inline log2_ceil in common.h.)
int log2_ceil(int value) {
    int exponent = 0;
    while (value > (1 << exponent)) {
        ++exponent;
    }
    return exponent;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
......@@ -113,8 +113,8 @@ __global__ void scaled_softmax_warp_forward(
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ?
next_power_of_two : THREADS_PER_WARP;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
......@@ -228,8 +228,8 @@ __global__ void scaled_masked_softmax_warp_forward(
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ?
next_power_of_two : THREADS_PER_WARP;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
......@@ -305,6 +305,13 @@ __global__ void scaled_masked_softmax_warp_forward(
}
warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);
// compute scale value to account for full mask
acc_t scale_value[WARP_BATCH];
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
scale_value[i] = (max_value[i] == -10000.0) ? 0.0 : 1.0;
}
acc_t sum[WARP_BATCH] { 0.0f };
#pragma unroll
for (int i = 0; i < WARP_BATCH; ++i) {
......@@ -328,7 +335,7 @@ __global__ void scaled_masked_softmax_warp_forward(
if (element_index < element_count) {
#pragma unroll
for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
out[element] = elements[i][it + element] / sum[i];
out[element] = elements[i][it + element] * scale_value[i] / sum[i];
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count
+ it * WARP_SIZE, out);
......@@ -350,8 +357,8 @@ __global__ void scaled_masked_softmax_warp_backward(
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ?
next_power_of_two : THREADS_PER_WARP;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
......@@ -439,21 +446,7 @@ __global__ void scaled_masked_softmax_warp_backward(
}
}
}
} // end of anonymous namespace
// Number of batch rows processed by one 128-thread block in the fused
// masked-softmax kernels. Mirrors the WARP_SIZE / WARP_BATCH constexpr
// computation inside the kernels; only key_seq_len affects the result,
// the other parameters are kept for the existing call signature.
int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    const int next_power_of_two = 1 << log2_ceil(key_seq_len);
    const int warp_size =
        (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int threads_per_block = 128;
    const int warps_per_block = threads_per_block / warp_size;
    return warps_per_block * batches_per_warp;
}
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
......@@ -463,8 +456,9 @@ void dispatch_scaled_softmax_forward(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
int attn_heads,
cudaStream_t stream) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape.");
if (key_seq_len == 0) {
return;
} else {
......@@ -474,7 +468,8 @@ void dispatch_scaled_softmax_forward(
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_forward.
......@@ -485,14 +480,14 @@ void dispatch_scaled_softmax_forward(
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
NVTE_CHECK(query_seq_len%batches_per_block == 0, "Unsupported shape.");
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -500,7 +495,7 @@ void dispatch_scaled_softmax_forward(
break;
case 1: // 2
scaled_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -508,7 +503,7 @@ void dispatch_scaled_softmax_forward(
break;
case 2: // 4
scaled_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -516,7 +511,7 @@ void dispatch_scaled_softmax_forward(
break;
case 3: // 8
scaled_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -524,7 +519,7 @@ void dispatch_scaled_softmax_forward(
break;
case 4: // 16
scaled_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -532,7 +527,7 @@ void dispatch_scaled_softmax_forward(
break;
case 5: // 32
scaled_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -540,7 +535,7 @@ void dispatch_scaled_softmax_forward(
break;
case 6: // 64
scaled_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -548,7 +543,7 @@ void dispatch_scaled_softmax_forward(
break;
case 7: // 128
scaled_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -556,7 +551,7 @@ void dispatch_scaled_softmax_forward(
break;
case 8: // 256
scaled_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -564,7 +559,7 @@ void dispatch_scaled_softmax_forward(
break;
case 9: // 512
scaled_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -572,7 +567,7 @@ void dispatch_scaled_softmax_forward(
break;
case 10: // 1024
scaled_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -580,7 +575,7 @@ void dispatch_scaled_softmax_forward(
break;
case 11: // 2048
scaled_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -588,7 +583,7 @@ void dispatch_scaled_softmax_forward(
break;
case 12: // 4096
scaled_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -610,8 +605,9 @@ void dispatch_scaled_masked_softmax_forward(
int key_seq_len,
int batches,
int attn_heads,
int pad_batches) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
int pad_batches,
cudaStream_t stream) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape.");
if (key_seq_len == 0) {
return;
} else {
......@@ -621,7 +617,8 @@ void dispatch_scaled_masked_softmax_forward(
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_forward.
......@@ -632,14 +629,14 @@ void dispatch_scaled_masked_softmax_forward(
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(query_seq_len%batches_per_block == 0);
NVTE_CHECK(query_seq_len%batches_per_block == 0, "Unsupported shape.");
dim3 blocks(query_seq_len/batches_per_block, attn_heads, batches);
dim3 threads(warp_size, warps_per_block, 1);
// Launch code would be more elegant if C++ supported FOR CONSTEXPR
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -649,7 +646,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 1: // 2
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -659,7 +656,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 2: // 4
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -669,7 +666,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 3: // 8
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -679,7 +676,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 4: // 16
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -689,7 +686,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 5: // 32
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -699,7 +696,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 6: // 64
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -709,7 +706,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 7: // 128
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -719,7 +716,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 8: // 256
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -729,7 +726,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 9: // 512
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -739,7 +736,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 10: // 1024
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -749,7 +746,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 11: // 2048
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -759,7 +756,7 @@ void dispatch_scaled_masked_softmax_forward(
break;
case 12: // 4096
scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
mask,
scale,
......@@ -782,8 +779,9 @@ void dispatch_scaled_masked_softmax_backward(
int query_seq_len,
int key_seq_len,
int batches,
int attn_heads) {
TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
int attn_heads,
cudaStream_t stream) {
NVTE_CHECK(key_seq_len >= 0 && key_seq_len <= 4096, "Unsupported shape.");
if (key_seq_len == 0) {
return;
} else {
......@@ -793,7 +791,8 @@ void dispatch_scaled_masked_softmax_backward(
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_backward.
......@@ -810,7 +809,7 @@ void dispatch_scaled_masked_softmax_backward(
switch (log2_elements) {
case 0: // 1
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -819,7 +818,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 1: // 2
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -829,7 +828,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 2: // 4
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -839,7 +838,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 3: // 8
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -849,7 +848,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 4: // 16
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -859,7 +858,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 5: // 32
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -869,7 +868,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 6: // 64
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -879,7 +878,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 7: // 128
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -889,7 +888,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 8: // 256
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -899,7 +898,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 9: // 512
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -909,7 +908,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 10: // 1024
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -919,7 +918,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 11: // 2048
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -929,7 +928,7 @@ void dispatch_scaled_masked_softmax_backward(
break;
case 12: // 4096
scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad,
output,
scale,
......@@ -944,4 +943,172 @@ void dispatch_scaled_masked_softmax_backward(
}
}
#endif // TRANSFORMER_ENGINE_SCALED_MASKED_SOFTMAX_H_
/*
 * Forward pass of the fused scaled softmax (no attention mask).
 *
 * input           : 4-D tensor read as [batches, attn_heads, query_seq_len,
 *                   key_seq_len] (dims taken from input.shape[0..3]);
 *                   dtype must be fp16 or bf16.
 * softmax_results : destination tensor written by the kernel.
 * scale_factor    : scale applied to the logits by the kernel.
 * stream          : CUDA stream the kernel is launched on.
 *
 * Dispatches on input.dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT; any
 * other dtype aborts inside the macro with NVTE_ERROR.
 */
void scaled_softmax_forward(
const Tensor &input,
Tensor *softmax_results,
float scale_factor,
cudaStream_t stream) {
const int batches = input.shape[0];
const int attn_heads = input.shape[1];
const int query_seq_len = input.shape[2];
const int key_seq_len = input.shape[3];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type,
dispatch_scaled_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr),
reinterpret_cast<const softmax_type*>(input.dptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
stream););
}
/*
 * Backward pass of the fused scaled softmax (no mask).
 *
 * output_grads    : 4-D gradient tensor [batches, attn_heads, query_seq_len,
 *                   key_seq_len]; its buffer is both the incoming gradient
 *                   and the destination (updated in place).
 * softmax_results : softmax output saved by the forward pass (read-only).
 * scale_factor    : same scale used in the forward pass.
 * stream          : CUDA stream the kernel is launched on.
 *
 * Tensors are now taken by const reference instead of by value, matching
 * scaled_softmax_forward and avoiding a needless copy of Tensor metadata.
 * Delegates to dispatch_scaled_masked_softmax_backward (the backward kernel
 * is shared with the masked variant).
 */
void scaled_softmax_backward(
    const Tensor &output_grads,
    const Tensor &softmax_results,
    float scale_factor,
    cudaStream_t stream) {
  // output_grads is a 4-D tensor: [batches, attn_heads, seq_len, seq_len].
  const int batches = output_grads.shape[0];
  const int attn_heads = output_grads.shape[1];
  const int query_seq_len = output_grads.shape[2];
  const int key_seq_len = output_grads.shape[3];

  // Softmax grad: grad_input and grad alias the same buffer (in-place).
  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
      dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
          reinterpret_cast<softmax_type*>(output_grads.dptr),
          reinterpret_cast<softmax_type*>(output_grads.dptr),
          reinterpret_cast<softmax_type const*>(softmax_results.dptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          stream););
}
/*
 * Forward pass of the fused scaled + masked softmax.
 *
 * input           : 4-D tensor read as [batches, attn_heads, query_seq_len,
 *                   key_seq_len]; dtype must be fp16 or bf16.
 * mask            : mask tensor passed to the kernel as uint8_t data;
 *                   mask.shape[0] supplies pad_batches.
 * softmax_results : destination tensor written by the kernel.
 * scale_factor    : scale applied to the logits by the kernel.
 * stream          : CUDA stream the kernel is launched on.
 *
 * Dispatches on input.dtype via TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT; any
 * other dtype aborts inside the macro with NVTE_ERROR.
 */
void scaled_masked_softmax_forward(
const Tensor input,
const Tensor mask,
Tensor *softmax_results,
float scale_factor,
cudaStream_t stream) {
const int batches = input.shape[0];
const int pad_batches = mask.shape[0];
const int attn_heads = input.shape[1];
const int query_seq_len = input.shape[2];
const int key_seq_len = input.shape[3];
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type,
dispatch_scaled_masked_softmax_forward<softmax_type, softmax_type, float>(
reinterpret_cast<softmax_type*>(softmax_results->dptr),
reinterpret_cast<const softmax_type*>(input.dptr),
reinterpret_cast<const uint8_t*>(mask.dptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches,
stream););
}
/*
 * Backward pass of the fused scaled + masked softmax.
 *
 * output_grads    : 4-D gradient tensor [batches, attn_heads, query_seq_len,
 *                   key_seq_len]; its buffer is both the incoming gradient
 *                   and the destination (updated in place).
 * softmax_results : softmax output saved by the forward pass (read-only).
 * scale_factor    : same scale used in the forward pass.
 * stream          : CUDA stream the kernel is launched on.
 *
 * Tensors are now taken by const reference instead of by value, matching
 * scaled_masked_softmax_forward's use of softmax_results and avoiding a
 * needless copy of Tensor metadata.
 */
void scaled_masked_softmax_backward(
    const Tensor &output_grads,
    const Tensor &softmax_results,
    float scale_factor,
    cudaStream_t stream
) {
  // output_grads is a 4-D tensor: [batches, attn_heads, seq_len, seq_len].
  const int batches = output_grads.shape[0];
  const int attn_heads = output_grads.shape[1];
  const int query_seq_len = output_grads.shape[2];
  const int key_seq_len = output_grads.shape[3];

  // Softmax grad: grad_input and grad alias the same buffer (in-place).
  TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
      dispatch_scaled_masked_softmax_backward<softmax_type, softmax_type, float>(
          reinterpret_cast<softmax_type*>(output_grads.dptr),
          reinterpret_cast<softmax_type*>(output_grads.dptr),
          reinterpret_cast<softmax_type const*>(softmax_results.dptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          stream););
}
} // end namespace transformer_engine
// C API entry point: thin wrapper that unwraps the opaque NVTETensor handles
// and forwards to transformer_engine::scaled_softmax_forward on `stream`.
void nvte_scaled_softmax_forward(
    const NVTETensor input,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
) {
  const auto *input_tensor =
      reinterpret_cast<const transformer_engine::Tensor *>(input);
  auto *results_tensor =
      reinterpret_cast<transformer_engine::Tensor *>(softmax_results);
  transformer_engine::scaled_softmax_forward(
      *input_tensor, results_tensor, scale_factor, stream);
}
// C API entry point: thin wrapper that unwraps the opaque NVTETensor handles
// and forwards to transformer_engine::scaled_softmax_backward on `stream`.
void nvte_scaled_softmax_backward(
    const NVTETensor output_grads,
    const NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
) {
  const auto *grads_tensor =
      reinterpret_cast<const transformer_engine::Tensor *>(output_grads);
  const auto *results_tensor =
      reinterpret_cast<const transformer_engine::Tensor *>(softmax_results);
  transformer_engine::scaled_softmax_backward(
      *grads_tensor, *results_tensor, scale_factor, stream);
}
// C API entry point: thin wrapper that unwraps the opaque NVTETensor handles
// and forwards to transformer_engine::scaled_masked_softmax_forward.
void nvte_scaled_masked_softmax_forward(
    const NVTETensor input,
    const NVTETensor mask,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
) {
  const auto *input_tensor =
      reinterpret_cast<const transformer_engine::Tensor *>(input);
  const auto *mask_tensor =
      reinterpret_cast<const transformer_engine::Tensor *>(mask);
  auto *results_tensor =
      reinterpret_cast<transformer_engine::Tensor *>(softmax_results);
  transformer_engine::scaled_masked_softmax_forward(
      *input_tensor, *mask_tensor, results_tensor, scale_factor, stream);
}
// C API entry point for the masked-softmax backward pass.
//
// output_grads    : gradient tensor (was misleadingly named `input`; renamed
//                   to match the sibling nvte_scaled_softmax_backward —
//                   parameter names do not affect the C ABI).
// softmax_results : forward-pass softmax output, read-only here, so it is
//                   cast to const Tensor* for consistency with the siblings.
// scale_factor    : scale used in the forward pass.
// stream          : CUDA stream the kernel is launched on.
void nvte_scaled_masked_softmax_backward(
    const NVTETensor output_grads,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
) {
  using namespace transformer_engine;
  scaled_masked_softmax_backward(
      *reinterpret_cast<const Tensor*>(output_grads),
      *reinterpret_cast<const Tensor*>(softmax_results),
      scale_factor,
      stream);
}
......@@ -4,42 +4,46 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_SCALED_UPPER_TRIANG_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SCALED_UPPER_TRIANG_SOFTMAX_H_
#include <transformer_engine/softmax.h>
#include <transformer_engine/logging.h>
#include <assert.h>
#include <cuda_fp16.h>
#include <stdint.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <c10/macros/Macros.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include "../utils.cuh"
#include "../common.h"
namespace {
namespace transformer_engine {
template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
__device__ __inline__ void copy_vector<bf16, 1>(bf16 *dst,
const bf16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst,
const c10::BFloat16 *src) {
__device__ __inline__ void copy_vector<bf16, 4>(bf16 *dst,
const bf16 *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
const c10::Half *src) {
__device__ __inline__ void copy_vector<fp16, 1>(fp16 *dst,
const fp16 *src) {
*dst = *src;
}
template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
const c10::Half *src) {
__device__ __inline__ void copy_vector<fp16, 4>(fp16 *dst,
const fp16 *src) {
*((float2*) dst) = *((float2*) src); // NOLINT(*)
}
......@@ -59,30 +63,23 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_zero_vector(Datatype *dst);
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) {
*dst = 0.0;
__device__ __inline__ void copy_zero_vector<bf16, 1>(bf16 *dst) {
*dst = 0.0f;
}
template <>
__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) {
__device__ __inline__ void copy_zero_vector<bf16, 4>(bf16 *dst) {
*((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*)
}
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }
__device__ __inline__ void copy_zero_vector<fp16, 1>(fp16 *dst) { *dst = 0.0f; }
template <>
__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) {
__device__ __inline__ void copy_zero_vector<fp16, 4>(fp16 *dst) {
*((float2*) dst) = make_float2(0.0f, 0.0f); // NOLINT(*)
}
// Smallest exponent e such that (1 << e) >= value; 0 for value <= 1.
// NOTE(review): duplicates the log2_ceil helper that appears earlier in this
// commit (common.h) — consider using that shared inline version instead.
int log2_ceil(int value) {
int log2_value = 0;
while ((1 << log2_value) < value) ++log2_value;
return log2_value;
}
template<typename T>
struct Add {
__device__ __forceinline__ T operator()(T a, T b) const {
......@@ -136,8 +133,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_forward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ?
next_power_of_two : THREADS_PER_WARP;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
......@@ -232,7 +229,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
if (element_index + element < local_seq) {
out[element] = elements[i][it + element] / sum[i];
} else {
out[element] = 0;
out[element] = 0.0f;
}
}
copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count * stride
......@@ -260,8 +257,8 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
// WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
// warp_size of method warp_softmax_backward_kernel.
constexpr int next_power_of_two = 1 << log2_elements;
constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ?
next_power_of_two : C10_WARP_SIZE;
constexpr int WARP_SIZE = (next_power_of_two < THREADS_PER_WARP) ?
next_power_of_two : THREADS_PER_WARP;
constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
......@@ -355,7 +352,6 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
}
}
} // end of anonymous namespace
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_upper_triang_masked_softmax_forward(
......@@ -364,8 +360,9 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
const input_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
int attn_batches,
cudaStream_t stream) {
NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape.");
if (softmax_elements == 0) {
return;
} else {
......@@ -376,7 +373,8 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_forward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_forward.
......@@ -387,7 +385,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape.");
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
......@@ -396,7 +394,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -405,7 +403,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -414,7 +412,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -423,7 +421,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -432,7 +430,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -441,7 +439,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -450,7 +448,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -459,7 +457,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -468,7 +466,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -477,7 +475,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -486,7 +484,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -495,7 +493,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_forward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst,
<<<blocks, threads, 0, stream>>>(dst,
src,
scale,
batch_count,
......@@ -516,8 +514,9 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
const acc_t scale,
int softmax_elements,
int softmax_elements_stride,
int attn_batches) {
TORCH_INTERNAL_ASSERT(softmax_elements >= 0 && softmax_elements <= 2048);
int attn_batches,
cudaStream_t stream) {
NVTE_CHECK(softmax_elements >= 0 && softmax_elements <= 2048, "Unsupported shape.");
if (softmax_elements == 0) {
return;
} else {
......@@ -528,7 +527,8 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
// This value must match the WARP_SIZE constexpr
// value computed inside softmax_warp_backward.
int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
int warp_size = (next_power_of_two < THREADS_PER_WARP) ? next_power_of_two
: THREADS_PER_WARP;
// This value must match the WARP_BATCH constexpr
// value computed inside softmax_warp_backward.
......@@ -539,7 +539,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
NVTE_CHECK(attn_batches % batches_per_block == 0, "Unsupported shape.");
int blocks_per_seq = attn_batches / batches_per_block;
dim3 blocks(seq_len, blocks_per_seq, 1);
......@@ -548,7 +548,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
switch (log2_elements) {
case 0: // 1
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 0>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -557,7 +557,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 1: // 2
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 1>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -566,7 +566,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 2: // 4
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 2>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -575,7 +575,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 3: // 8
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 3>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -584,7 +584,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 4: // 16
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 4>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -593,7 +593,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 5: // 32
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 5>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -602,7 +602,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 6: // 64
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 6>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -611,7 +611,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 7: // 128
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 7>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -620,7 +620,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 8: // 256
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 8>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -629,7 +629,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 9: // 512
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 9>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -638,7 +638,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 10: // 1024
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 10>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -647,7 +647,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
break;
case 11: // 2048
scaled_upper_triang_masked_softmax_warp_backward<input_t, output_t, acc_t, 11>
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(grad_input,
<<<blocks, threads, 0, stream>>>(grad_input,
grad, output,
scale,
batch_count,
......@@ -660,4 +660,78 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
}
}
#endif // TRANSFORMER_ENGINE_SCALED_UPPER_TRIANG_SOFTMAX_H_
// Forward pass of the upper-triangular (causal) masked, scaled softmax.
// input layout: [attn_batches, seq_len, seq_len]; results are written into
// *softmax_results asynchronously on `stream`.
void scaled_upper_triang_masked_softmax_forward(
    const Tensor input,
    Tensor *softmax_results,
    float scale_factor,
    cudaStream_t stream) {
    const int batch_dim    = input.shape[0];
    const int sequence_len = input.shape[1];

    // Dispatch on the 16-bit element type (fp16 or bf16).
    TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(input.dtype, softmax_type,
        dispatch_scaled_upper_triang_masked_softmax_forward<softmax_type, softmax_type, float>(
            reinterpret_cast<softmax_type*>(softmax_results->dptr),
            reinterpret_cast<const softmax_type*>(input.dptr),
            scale_factor,
            sequence_len,
            sequence_len,
            batch_dim,
            stream););
}
// Backward pass of the upper-triangular masked, scaled softmax.
// output_grads layout: [attn_batches, seq_len, seq_len]. grad_input and the
// incoming gradients share the same buffer, so the kernel runs fully in place.
void scaled_upper_triang_masked_softmax_backward(
    const Tensor output_grads,
    const Tensor softmax_results,
    float scale_factor,
    cudaStream_t stream) {
    const int batch_dim    = output_grads.shape[0];
    const int sequence_len = output_grads.shape[1];

    // Dispatch on the 16-bit element type (fp16 or bf16).
    TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(output_grads.dtype, softmax_type,
        dispatch_scaled_upper_triang_masked_softmax_backward<softmax_type, softmax_type, float>(
            reinterpret_cast<softmax_type*>(output_grads.dptr),
            reinterpret_cast<softmax_type*>(output_grads.dptr),
            reinterpret_cast<softmax_type const*>(softmax_results.dptr),
            scale_factor,
            sequence_len,
            sequence_len,
            batch_dim,
            stream););
}
} // end namespace transformer_engine
// C API entry point: unwraps the opaque NVTETensor handles and forwards to
// the internal implementation. Launches asynchronously on `stream`.
void nvte_scaled_upper_triang_masked_softmax_forward(
    const NVTETensor input,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
) {
    using namespace transformer_engine;
    const Tensor *input_tensor = reinterpret_cast<const Tensor*>(input);
    Tensor *results_tensor = reinterpret_cast<Tensor*>(softmax_results);
    scaled_upper_triang_masked_softmax_forward(*input_tensor, results_tensor,
                                               scale_factor, stream);
}
// C API entry point: unwraps the opaque NVTETensor handles and forwards to
// the internal implementation. Launches asynchronously on `stream`.
void nvte_scaled_upper_triang_masked_softmax_backward(
    const NVTETensor output_grads,
    const NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
) {
    using namespace transformer_engine;
    const Tensor *grads_tensor   = reinterpret_cast<const Tensor*>(output_grads);
    const Tensor *results_tensor = reinterpret_cast<const Tensor*>(softmax_results);
    scaled_upper_triang_masked_softmax_backward(*grads_tensor, *results_tensor,
                                                scale_factor, stream);
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_SOFTMAX_H_
#define TRANSFORMER_ENGINE_SOFTMAX_H_

/*! \file softmax.h
 *  \brief Framework-agnostic C API for the fused scaled softmax CUDA kernels.
 *  All entry points launch asynchronously on the supplied cudaStream_t.
 */

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "transformer_engine.h"

#ifdef __cplusplus
extern "C" {
#endif

/*! Compute softmax_results = softmax(scale_factor * input). */
void nvte_scaled_softmax_forward(
    const NVTETensor input,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
);

/*! Gradient of nvte_scaled_softmax_forward given the incoming gradients and
 *  the forward softmax results. */
void nvte_scaled_softmax_backward(
    const NVTETensor output_grads,
    const NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
);

/*! Scaled softmax forward with an explicit (padding) mask applied. */
void nvte_scaled_masked_softmax_forward(
    const NVTETensor input,
    const NVTETensor mask,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
);

/*! Gradient of nvte_scaled_masked_softmax_forward.
 *  NOTE(review): the first parameter is named `input` here, but by analogy
 *  with the other backward entry points it presumably carries the incoming
 *  gradients -- confirm against the implementation. */
void nvte_scaled_masked_softmax_backward(
    const NVTETensor input,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
);

/*! Scaled softmax forward with an implicit upper-triangular (causal) mask. */
void nvte_scaled_upper_triang_masked_softmax_forward(
    const NVTETensor input,
    NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
);

/*! Gradient of nvte_scaled_upper_triang_masked_softmax_forward. */
void nvte_scaled_upper_triang_masked_softmax_backward(
    const NVTETensor output_grads,
    const NVTETensor softmax_results,
    float scale_factor,
    cudaStream_t stream
);

#ifdef __cplusplus
}  // extern "C"
#endif

#endif  // TRANSFORMER_ENGINE_SOFTMAX_H_
......@@ -14,6 +14,7 @@
#include <transformer_engine/logging.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/softmax.h>
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cuda/CUDAContext.h>
......
......@@ -483,8 +483,219 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
}
// Scaled softmax forward for a 4D fp16/bf16 tensor
// [batches, attn_heads, query_seq_len, key_seq_len]; returns a new tensor.
at::Tensor scaled_softmax_forward(at::Tensor input,
                                  float scale_factor
) {
    using namespace transformer_engine;

    AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
               (input.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    const int batch_size = input.size(0);
    const int head_count = input.size(1);
    const int q_seq_len  = input.size(2);
    const int k_seq_len  = input.size(3);
    TORCH_CHECK(k_seq_len <= 4096);
    TORCH_CHECK(q_seq_len > 1);

    // Output mirrors the input's dtype/device; no autograd tracking needed.
    auto opts = input.options().requires_grad(false);
    at::Tensor softmax_results =
        torch::empty({batch_size, head_count, q_seq_len, k_seq_len}, opts);

    auto input_cu   = makeTransformerEngineTensor(input);
    auto results_cu = makeTransformerEngineTensor(softmax_results);

    nvte_scaled_softmax_forward(input_cu.data(), results_cu.data(), scale_factor,
                                at::cuda::getCurrentCUDAStream());

    return softmax_results;
}
// Scaled softmax backward: computes the gradient from the incoming gradients
// and the forward softmax results, and returns it (output_grads is reused as
// the result buffer).
at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
                                   at::Tensor softmax_results_,
                                   float scale_factor
) {
    using namespace transformer_engine;

    // The kernels require contiguous buffers.
    auto output_grads    = output_grad_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
               (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
               (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    auto grads_cu   = makeTransformerEngineTensor(output_grads);
    auto results_cu = makeTransformerEngineTensor(softmax_results);

    nvte_scaled_softmax_backward(
        grads_cu.data(), results_cu.data(),
        scale_factor, at::cuda::getCurrentCUDAStream());

    return output_grads;
}
// Scaled masked softmax forward. input: [batches, attn_heads, q_len, k_len];
// mask: [1 or batches, 1, q_len, k_len] (broadcast over heads). Returns a new
// tensor with the masked softmax results.
at::Tensor scaled_masked_softmax_forward(at::Tensor input,
                                         at::Tensor mask,
                                         float scale_factor
) {
    using namespace transformer_engine;

    AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
               (input.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

    const int batch_size   = input.size(0);
    const int mask_batches = mask.size(0);
    const int head_count   = input.size(1);
    const int q_seq_len    = input.size(2);
    const int k_seq_len    = input.size(3);
    TORCH_CHECK(k_seq_len <= 4096);
    TORCH_CHECK(q_seq_len > 1);
    // The mask batch dim must be 1 (shared) or match the input batch dim.
    TORCH_CHECK(mask_batches == 1 || mask_batches == batch_size);
    TORCH_CHECK(mask.size(1) == 1);
    TORCH_CHECK(mask.size(2) == q_seq_len);
    TORCH_CHECK(mask.size(3) == k_seq_len);

    auto opts = input.options().requires_grad(false);
    at::Tensor softmax_results =
        torch::empty({batch_size, head_count, q_seq_len, k_seq_len}, opts);

    auto input_cu   = makeTransformerEngineTensor(input);
    auto mask_cu    = makeTransformerEngineTensor(mask);
    auto results_cu = makeTransformerEngineTensor(softmax_results);

    nvte_scaled_masked_softmax_forward(
        input_cu.data(), mask_cu.data(), results_cu.data(),
        scale_factor, at::cuda::getCurrentCUDAStream());

    return softmax_results;
}
// Scaled masked softmax backward. Returns the input gradient computed from
// the incoming gradients and the forward softmax results.
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
                                          at::Tensor softmax_results_,
                                          float scale_factor
) {
    using namespace transformer_engine;

    // The kernels require contiguous buffers.
    auto output_grads    = output_grad_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    // Fix: the checks enforce 4D tensors but the messages claimed "3D".
    AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
               (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
               (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    auto output_grads_cu    = makeTransformerEngineTensor(output_grads);
    auto softmax_results_cu = makeTransformerEngineTensor(softmax_results);

    // NOTE(review): this deliberately calls the unmasked softmax backward --
    // presumably valid because masked positions produce zero forward outputs;
    // confirm against the kernel implementation.
    nvte_scaled_softmax_backward(
        output_grads_cu.data(), softmax_results_cu.data(),
        scale_factor, at::cuda::getCurrentCUDAStream());

    return output_grads;
}
// Causal (upper-triangular masked) scaled softmax forward for a 3D tensor
// [attn_batches, seq_len, seq_len]; returns a new tensor of the same shape.
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
                                                      float scale_factor
) {
    using namespace transformer_engine;

    AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
    AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
               (input.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    const int batch_dim    = input.size(0);
    const int sequence_len = input.size(1);
    TORCH_CHECK(sequence_len <= 2048);

    // Square [seq_len, seq_len] output per attention batch.
    auto opts = input.options().requires_grad(false);
    at::Tensor softmax_results =
        torch::empty({batch_dim, sequence_len, sequence_len}, opts);

    auto input_cu   = makeTransformerEngineTensor(input);
    auto results_cu = makeTransformerEngineTensor(softmax_results);

    nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(),
                                                    results_cu.data(),
                                                    scale_factor,
                                                    at::cuda::getCurrentCUDAStream());

    return softmax_results;
}
// Causal scaled softmax backward; 3D tensors [attn_batches, seq_len, seq_len].
// Returns the gradient (output_grads is reused as the result buffer).
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
                                                       float scale_factor
) {
    using namespace transformer_engine;

    // The kernels require contiguous buffers.
    auto output_grads    = output_grads_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
    AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
               (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
               (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    // The causal mask only makes sense for a square attention matrix.
    TORCH_CHECK(output_grads.size(1) == output_grads.size(2));

    auto grads_cu   = makeTransformerEngineTensor(output_grads);
    auto results_cu = makeTransformerEngineTensor(softmax_results);

    nvte_scaled_upper_triang_masked_softmax_backward(grads_cu.data(),
                                                     results_cu.data(),
                                                     scale_factor,
                                                     at::cuda::getCurrentCUDAStream());

    return output_grads;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Granular functions
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
m.def("scaled_softmax_backward", &scaled_softmax_backward, "Scaled Softmax BWD");
m.def("scaled_masked_softmax_forward", &scaled_masked_softmax_forward,
"Scaled Masked Softmax FWD");
m.def("scaled_masked_softmax_backward", &scaled_masked_softmax_backward,
"Scaled Masked Softmax BWD");
m.def("scaled_upper_triang_masked_softmax_forward",
&scaled_upper_triang_masked_softmax_forward,
"Scaled Upper-Triangular Masked Softmax FWD");
m.def("scaled_upper_triang_masked_softmax_backward",
&scaled_upper_triang_masked_softmax_backward,
"Scaled Upper-Triangular Masked Softmax BWD");
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
......
......@@ -116,3 +116,37 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
transformer_engine::DType itype,
transformer_engine::DType otype
);
// PyTorch-facing wrappers around the NVTE fused softmax kernels.
// All launch on the current CUDA stream; fp16/bf16 only.

// Forward of softmax(scale_factor * input) for a 4D input; returns a new tensor.
at::Tensor scaled_softmax_forward(at::Tensor input,
                                  float scale_factor
);
// Backward of scaled_softmax_forward; returns the input gradient.
at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
                                   at::Tensor softmax_results_,
                                   float scale_factor
);
// Forward with an explicit padding mask (4D input and mask).
at::Tensor scaled_masked_softmax_forward(at::Tensor input,
                                         at::Tensor mask,
                                         float scale_factor
);
// Backward of scaled_masked_softmax_forward; returns the input gradient.
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
                                          at::Tensor softmax_results_,
                                          float scale_factor
);
// Forward with an implicit causal (upper-triangular) mask (3D input).
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
                                                      float scale_factor
);
// Backward of the causal variant; returns the input gradient.
at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
                                                       float scale_factor
);
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
// Older PyTorch releases only provide AT_CHECK; alias it so the rest of the
// code can use the modern TORCH_CHECK name unconditionally.
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
// torch::Tensor::data() was renamed to data_ptr() in PyTorch 1.3; select the
// right accessor name from the build-time version flag.
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_masked_softmax {
// CUDA-side implementations (defined in the companion .cu translation unit).
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);
// Presumably the number of batches handled per CUDA block for this problem
// size -- see the kernel-side helper it forwards to.
int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);
// Validates rank and dtype on the binding side, then defers to the CUDA
// implementation.
torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
    const bool dtype_supported =
        (input.scalar_type() == at::ScalarType::Half) ||
        (input.scalar_type() == at::ScalarType::BFloat16);
    AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(dtype_supported, "Only fp16 and bf16 are supported");
    AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

    return fwd_cuda(input, mask, scale_factor);
}
// Validates rank and dtype, then defers to the CUDA backward implementation.
torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
    // Fix: the checks enforce 4D tensors but the messages claimed "3D".
    AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
               (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
               (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return bwd_cuda(output_grads, softmax_results, scale_factor);
}
// Thin host-side shim; the actual computation lives in the CUDA translation
// unit (get_batch_per_block_cuda).
int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
    const int batch_per_block =
        get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
    return batch_per_block;
}
} // end namespace scaled_masked_softmax
} // end namespace transformer_engine
// Python bindings for the scaled masked softmax extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &transformer_engine::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &transformer_engine::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  // Exposes the kernel's batches-per-block value so Python can size inputs.
  m.def("get_batch_per_block",
        &transformer_engine::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_masked_softmax {
// Host-callable entry point exposed to the binding translation unit; forwards
// to the header-defined get_batch_per_block helper.
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
// Launches the masked softmax forward kernels.
// input: [batches, attn_heads, query_seq_len, key_seq_len]
// mask:  [pad_batches, 1, query_seq_len, key_seq_len], pad_batches in {1, batches}
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor) {
    const int batch_size   = input.size(0);
    const int mask_batches = mask.size(0);
    const int head_count   = input.size(1);
    const int q_len        = input.size(2);
    const int k_len        = input.size(3);
    TORCH_INTERNAL_ASSERT(k_len <= 4096);
    TORCH_INTERNAL_ASSERT(q_len > 1);
    TORCH_INTERNAL_ASSERT(mask_batches == 1 || mask_batches == batch_size);
    TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
    TORCH_INTERNAL_ASSERT(mask.size(2) == q_len);
    TORCH_INTERNAL_ASSERT(mask.size(3) == k_len);

    // Output mirrors the input's shape/dtype; no autograd tracking.
    auto opts = input.options().requires_grad(false);
    torch::Tensor softmax_results =
        torch::empty({batch_size, head_count, q_len, k_len}, opts);

    // Dispatch on fp16/bf16 element type.
    DISPATCH_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch_scaled_masked_softmax_forward",
        dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(softmax_results.data_ptr()),
            reinterpret_cast<const scalar_t*>(input.data_ptr()),
            reinterpret_cast<const uint8_t*>(mask.data_ptr()),
            scale_factor,
            q_len,
            k_len,
            batch_size,
            head_count,
            mask_batches););
    return softmax_results;
}
// Launches the masked softmax backward kernels.
// output_grads: [batches, attn_heads, query_seq_len, key_seq_len]
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {
    // Kernels need contiguous buffers.
    auto output_grads    = output_grads_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    const int batch_size = output_grads.size(0);
    const int head_count = output_grads.size(1);
    const int q_len      = output_grads.size(2);
    const int k_len      = output_grads.size(3);

    // grad_input and the incoming gradients share one buffer, so the backward
    // pass is completely in place.
    void* grads_ptr = static_cast<void*>(output_grads.data_ptr());
    DISPATCH_HALF_AND_BFLOAT(
        output_grads_.scalar_type(),
        "dispatch_scaled_masked_softmax_backward",
        dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(grads_ptr),
            reinterpret_cast<scalar_t*>(grads_ptr),
            reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
            scale_factor,
            q_len,
            k_len,
            batch_size,
            head_count););
    return output_grads;
}
} // end namespace scaled_masked_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_softmax {
// CUDA-side implementations (defined in the companion .cu translation unit).
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);
// Validates rank and dtype on the binding side, then defers to the CUDA
// implementation.
torch::Tensor fwd(
    torch::Tensor const& input,
    float scale_factor) {
    const bool dtype_supported =
        (input.scalar_type() == at::ScalarType::Half) ||
        (input.scalar_type() == at::ScalarType::BFloat16);
    AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(dtype_supported, "Only fp16 and bf16 are supported");

    return fwd_cuda(input, scale_factor);
}
// Validates rank and dtype, then defers to the CUDA backward implementation.
torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
    // Fix: the checks enforce 4D tensors but the messages claimed "3D".
    AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
    AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
    AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
               (output_grads.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");
    AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
               (softmax_results.scalar_type() == at::ScalarType::BFloat16),
               "Only fp16 and bf16 are supported");

    return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_softmax
} // end namespace transformer_engine
// Python bindings for the (unmasked) scaled softmax extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &transformer_engine::scaled_softmax::fwd,
        "Self Multihead Attention scaled, softmax -- Forward.");
  m.def("backward",
        &transformer_engine::scaled_softmax::bwd,
        "Self Multihead Attention scaled, softmax -- Backward.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_softmax {
// Launches the scaled softmax forward kernels.
// input: [batches, attn_heads, query_seq_len, key_seq_len]
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor) {
    const int batch_size = input.size(0);
    const int head_count = input.size(1);
    const int q_len      = input.size(2);
    const int k_len      = input.size(3);
    TORCH_INTERNAL_ASSERT(k_len <= 4096);
    TORCH_INTERNAL_ASSERT(q_len > 1);

    // Output mirrors the input's shape/dtype; no autograd tracking.
    auto opts = input.options().requires_grad(false);
    torch::Tensor softmax_results =
        torch::empty({batch_size, head_count, q_len, k_len}, opts);

    // Dispatch on fp16/bf16 element type.
    DISPATCH_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch_scaled_softmax_forward",
        dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(softmax_results.data_ptr()),
            reinterpret_cast<const scalar_t*>(input.data_ptr()),
            scale_factor,
            q_len,
            k_len,
            batch_size,
            head_count););
    return softmax_results;
}
// Launch the scaled softmax backward kernel.
// Gradients are computed in-place over a contiguous copy of `output_grads_`
// (4D: [batches, attn_heads, query_seq_len, key_seq_len]) and that tensor
// is returned.
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {
    auto grads = output_grads_.contiguous();
    auto probs = softmax_results_.contiguous();

    const int num_batches = grads.size(0);
    const int num_heads   = grads.size(1);
    const int q_len       = grads.size(2);
    const int k_len       = grads.size(3);

    void* grads_ptr = static_cast<void*>(grads.data_ptr());

    // The kernel reads and writes the same buffer: input gradients are
    // overwritten with the softmax gradients.
    DISPATCH_HALF_AND_BFLOAT(
        output_grads_.scalar_type(),
        "dispatch_scaled_masked_softmax_backward",
        dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(grads_ptr),
            reinterpret_cast<scalar_t*>(grads_ptr),
            reinterpret_cast<scalar_t const*>(probs.data_ptr()),
            scale_factor,
            q_len,
            k_len,
            num_batches,
            num_heads););
    return grads;
}
} // end namespace scaled_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace transformer_engine {
namespace scaled_upper_triang_masked_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor);
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads,
torch::Tensor const& softmax_results,
float scale_factor);
// Upper-triangular (causal) masked softmax forward wrapper: checks that the
// input is a 3D fp16/bf16 tensor before launching the CUDA implementation.
torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
    const auto dtype = input.scalar_type();
    AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
    AT_ASSERTM(dtype == at::ScalarType::Half ||
               dtype == at::ScalarType::BFloat16,
               "Only fp16 and bf16 are supported");
    return fwd_cuda(input, scale_factor);
}
// Upper-triangular (causal) masked softmax backward wrapper: both operands
// must be 3D fp16/bf16 tensors; delegates to the CUDA implementation.
torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor) {
    const auto grads_dtype = output_grads.scalar_type();
    const auto probs_dtype = softmax_results.scalar_type();
    AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
    AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
    AT_ASSERTM(grads_dtype == at::ScalarType::Half ||
               grads_dtype == at::ScalarType::BFloat16,
               "Only fp16 and bf16 are supported");
    AT_ASSERTM(probs_dtype == at::ScalarType::Half ||
               probs_dtype == at::ScalarType::BFloat16,
               "Only fp16 and bf16 are supported");
    return bwd_cuda(output_grads, softmax_results, scale_factor);
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace transformer_engine
// Python bindings for the scaled upper-triangular (causal) masked softmax
// extension module. Exposes the validating wrappers fwd/bwd as
// `forward`/`backward`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward",
&transformer_engine::scaled_upper_triang_masked_softmax::fwd,
"Self Multihead Attention scaled, time masked softmax -- Forward.");
m.def("backward",
&transformer_engine::scaled_upper_triang_masked_softmax::bwd,
"Self Multihead Attention scaled, time masked softmax -- Backward.");
}
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_upper_triang_masked_softmax.h"
#include "type_shim.h"
namespace transformer_engine {
namespace scaled_upper_triang_masked_softmax {
// Launch the causal (upper-triangular masked) softmax forward kernel.
// `input` is a 3D tensor [attn_batches, seq_len, seq_len]; a fresh tensor of
// the same shape holding the softmax probabilities is returned. The kernel
// supports seq_len up to 2048.
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor) {
    const int attn_batches = input.size(0);
    const int seq_len      = input.size(1);
    TORCH_INTERNAL_ASSERT(seq_len <= 2048);

    // Allocate the output; no autograd tracking on the raw kernel result.
    auto options = input.options().requires_grad(false);
    torch::Tensor probs =
        torch::empty({attn_batches, seq_len, seq_len}, options);

    void* src_ptr = static_cast<void*>(input.data_ptr());
    void* dst_ptr = static_cast<void*>(probs.data_ptr());

    DISPATCH_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch_scaled_upper_triang_masked_softmax_forward",
        dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(dst_ptr),
            reinterpret_cast<const scalar_t*>(src_ptr),
            scale_factor,
            seq_len,
            seq_len,
            attn_batches););
    return probs;
}
// Launch the causal (upper-triangular masked) softmax backward kernel.
// Gradients are computed in-place over a contiguous copy of `output_grads_`
// (3D: [attn_batches, seq_len, seq_len], square in the last two dims) and
// that tensor is returned.
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {
    auto grads = output_grads_.contiguous();
    auto probs = softmax_results_.contiguous();

    const int attn_batches = grads.size(0);
    const int seq_len      = grads.size(1);
    TORCH_INTERNAL_ASSERT(grads.size(1) == grads.size(2));

    void* grads_ptr = static_cast<void*>(grads.data_ptr());

    // The kernel reads and writes the same buffer: input gradients are
    // overwritten with the softmax gradients.
    DISPATCH_HALF_AND_BFLOAT(
        output_grads_.scalar_type(),
        "dispatch_scaled_upper_triang_masked_softmax_backward",
        dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(grads_ptr),
            reinterpret_cast<scalar_t*>(grads_ptr),
            reinterpret_cast<scalar_t const*>(probs.data_ptr()),
            scale_factor,
            seq_len,
            seq_len,
            attn_batches););
    return grads;
}
} // end namespace scaled_upper_triang_masked_softmax
} // end namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include "compat.h"
/* Dispatch on a runtime fp16/bf16 dtype: binds `scalar_t` to the matching
 * ATen type and expands __VA_ARGS__ for that case. NAME is a string used in
 * the error raised for unsupported dtypes.
 *
 * Fix: every visible call site already passes NAME as a string literal, so
 * AT_ERROR receives it directly; the previous `#NAME` stringified the
 * literal again, embedding escaped quotes in the error message. */
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                            \
switch (TYPE)                                                                \
{                                                                            \
case at::ScalarType::Half:                                                   \
    {                                                                        \
    using scalar_t = at::Half;                                               \
    __VA_ARGS__;                                                             \
    break;                                                                   \
    }                                                                        \
case at::ScalarType::BFloat16:                                               \
    {                                                                        \
    using scalar_t = at::BFloat16;                                           \
    __VA_ARGS__;                                                             \
    break;                                                                   \
    }                                                                        \
default:                                                                     \
    AT_ERROR(NAME, " not implemented for '", toString(TYPE), "'");           \
}
/* Dispatch on a runtime fp16/bf16/fp32 dtype: binds `scalar_t` to the
 * matching type and expands __VA_ARGS__ for that case. NAME is a string used
 * in the error raised for unsupported dtypes.
 *
 * Fix: NAME arrives as a string literal at call sites, so pass it to
 * AT_ERROR directly instead of re-stringifying with `#NAME` (which embedded
 * escaped quotes in the message). */
#define DISPATCH_HALF_BFLOAT_AND_FLOAT(TYPE, NAME, ...)                      \
switch (TYPE)                                                                \
{                                                                            \
case at::ScalarType::Half:                                                   \
    {                                                                        \
    using scalar_t = at::Half;                                               \
    __VA_ARGS__;                                                             \
    break;                                                                   \
    }                                                                        \
case at::ScalarType::BFloat16:                                               \
    {                                                                        \
    using scalar_t = at::BFloat16;                                           \
    __VA_ARGS__;                                                             \
    break;                                                                   \
    }                                                                        \
case at::ScalarType::Float:                                                  \
    {                                                                        \
    using scalar_t = float;                                                  \
    __VA_ARGS__;                                                             \
    break;                                                                   \
    }                                                                        \
default:                                                                     \
    AT_ERROR(NAME, " not implemented for '", toString(TYPE), "'");           \
}
/* Two-level dispatch on input/output dtypes: binds `scalar_t_in` from TYPEIN
 * and `scalar_t_out` from TYPEOUT, then expands __VA_ARGS__. fp32 input may
 * pair with fp32/fp16/bf16 output; fp16 and bf16 inputs only pair with the
 * same output type. NAME is a string used in the unsupported-dtype errors.
 *
 * Fix: NAME is a string literal at call sites, so AT_ERROR takes it directly
 * instead of the doubly-quoting `#NAME` stringification. */
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN)                                                              \
{                                                                            \
case at::ScalarType::Float:                                                  \
    {                                                                        \
    using scalar_t_in = float;                                               \
    switch (TYPEOUT)                                                         \
    {                                                                        \
    case at::ScalarType::Float:                                              \
        {                                                                    \
        using scalar_t_out = float;                                          \
        __VA_ARGS__;                                                         \
        break;                                                               \
        }                                                                    \
    case at::ScalarType::Half:                                               \
        {                                                                    \
        using scalar_t_out = at::Half;                                       \
        __VA_ARGS__;                                                         \
        break;                                                               \
        }                                                                    \
    case at::ScalarType::BFloat16:                                           \
        {                                                                    \
        using scalar_t_out = at::BFloat16;                                   \
        __VA_ARGS__;                                                         \
        break;                                                               \
        }                                                                    \
    default:                                                                 \
        AT_ERROR(NAME, " not implemented for '", toString(TYPEOUT), "'");    \
    }                                                                        \
    break;                                                                   \
    }                                                                        \
case at::ScalarType::Half:                                                   \
    {                                                                        \
    using scalar_t_in = at::Half;                                            \
    using scalar_t_out = at::Half;                                           \
    __VA_ARGS__;                                                             \
    break;                                                                   \
    }                                                                        \
case at::ScalarType::BFloat16:                                               \
    {                                                                        \
    using scalar_t_in = at::BFloat16;                                        \
    using scalar_t_out = at::BFloat16;                                       \
    __VA_ARGS__;                                                             \
    break;                                                                   \
    }                                                                        \
default:                                                                     \
    AT_ERROR(NAME, " not implemented for '", toString(TYPEIN), "'");         \
}
......@@ -9,6 +9,11 @@ from typing import Callable, Tuple, Union
import torch
from torch import nn
import transformer_engine_extensions as tex
THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
......@@ -21,10 +26,8 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledUpperTriangMaskedSoftmax fwd"""
import scaled_upper_triang_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
softmax_results = tex.scaled_upper_triang_masked_softmax_forward(
inputs, scale_t[0]
)
......@@ -36,10 +39,8 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledUpperTriangMaskedSoftmax bwd"""
import scaled_upper_triang_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
input_grads = tex.scaled_upper_triang_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
......@@ -59,11 +60,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
ctx, inputs: torch.Tensor, mask: torch.Tensor, scale: float
) -> torch.Tensor:
"""ScaledMaskedSoftmax fwd"""
import scaled_masked_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
softmax_results = tex.scaled_masked_softmax_forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
......@@ -72,11 +71,9 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledMaskedSoftmax bwd"""
import scaled_masked_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_masked_softmax_cuda.backward(
input_grads = tex.scaled_masked_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
......@@ -92,11 +89,9 @@ class ScaledSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs: torch.Tensor, scale: float) -> torch.Tensor:
"""ScaledSoftmax fwd"""
import scaled_softmax_cuda
scale_t = torch.tensor([scale])
softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0])
softmax_results = tex.scaled_softmax_forward(inputs, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
......@@ -105,11 +100,9 @@ class ScaledSoftmax(torch.autograd.Function):
ctx, output_grads: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""ScaledSoftmax bwd"""
import scaled_softmax_cuda
softmax_results, scale_t = ctx.saved_tensors
input_grads = scaled_softmax_cuda.backward(
input_grads = tex.scaled_softmax_backward(
output_grads, softmax_results, scale_t[0]
)
return input_grads, None, None
......@@ -170,7 +163,7 @@ class FusedScaleMaskSoftmax(nn.Module):
and attn_batches % 4 == 0 # np * b must be divisor of 4
):
if 0 <= sk <= 4096:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)
batch_per_block = self.get_batch_per_block(sk)
if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
......@@ -220,8 +213,11 @@ class FusedScaleMaskSoftmax(nn.Module):
return probs
@staticmethod
def get_batch_per_block(sq: int, sk: int, b: int, np: int) -> int:
def get_batch_per_block(key_seq_len: int) -> int:
"""Softmax utility"""
import scaled_masked_softmax_cuda
return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
pow2 = 1 << (key_seq_len - 1).bit_length()
warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
batches_per_warp = 2 if pow2 <= 128 else 1
warps_per_block = THREADS_PER_BLOCK / warp_size
batches_per_block = warps_per_block * batches_per_warp
return batches_per_block
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment