Unverified Commit e0e3d123 authored by vasunvidia, committed by GitHub

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG

Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix license

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid ambiguous types

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expand error message

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary helper function in PyTorch extensions

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 607fcc43
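
The kernel added in this PR converts the dropout probability to an 8-bit threshold (floor(p * 256)), compares each random byte from the Philox generator against that threshold with a strict less-than, and scales surviving values by 1/(1 - p). A rough Python sketch of that logic, for orientation only (the function below is made up for illustration and is not part of the PR):

import math
import random

def dropout_8bit_reference(values, p, seed=0):
    """Illustrative model of dropout with an 8-bit RNG threshold."""
    threshold = math.floor(p * 256)    # 8-bit representation of the drop probability
    scale = 1.0 / (1.0 - p)
    rng = random.Random(seed)
    out, mask = [], []
    for v in values:
        byte = rng.randrange(256)      # stand-in for one byte of Philox output
        keep = not (byte < threshold)  # strict less-than, matching the kernel
        mask.append(1 if keep else 0)
        out.append(v * scale if keep else 0.0)
    return out, mask

Because the threshold is floored, the achieved drop probability is floor(256 * p) / 256 (e.g. p = 0.1 gives 25/256 ≈ 0.0977), which is presumably why the updated test parametrizes over probabilities such as 0.0625, 0.5, and 0.75 that are exact multiples of 1/256.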
@@ -1749,25 +1749,44 @@ class TestBasicOps:
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)
 
-    @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75))
+    @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75))
     @pytest.mark.parametrize("is_training", (True, False))
-    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16)))
+    @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling"))
+    @pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128)))
     @pytest.mark.parametrize("dtype", _dtypes)
     def test_dropout(
         self,
         *,
         prob: float,
         is_training: bool,
+        quantization: Optional[str],
         shape: Iterable[int],
         dtype: torch.dtype,
         device: torch.device = "cuda",
     ):
 
+        # Skip invalid configurations
+        quantized_input = quantization is not None
+        maybe_skip_quantization(quantization, dims=shape, device=device)
+
         # Random data
-        x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        x_test = x_ref.clone().requires_grad_()
-        dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5
-        dy_test = dy_ref.clone()
+        # Note: Shift values to make sure inputs are non-zero
+        x_ref, x_test = make_reference_and_test_tensors(
+            shape,
+            quantization=quantization,
+            test_dtype=dtype,
+            test_device=device,
+            test_is_quantized=quantized_input,
+        )
+        with torch.no_grad():
+            x_test += 1
+            x_ref.copy_(x_test)
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
 
         # Apply dropout
         op = te_ops.Dropout(prob)
@@ -1775,17 +1794,20 @@ class TestBasicOps:
             op.train()
         else:
             op.eval()
-        y = op(x_test)
-        y.backward(dy_test)
+        y_test = op(x_test)
+        y_test.backward(dy_test)
 
         # Check values
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
         if is_training:
-            mask = ((y != 0) / (1 - prob)).to(dtype=dtype)
-            torch.testing.assert_close(y, x_ref * mask)
-            torch.testing.assert_close(x_test.grad, dy_ref * mask)
+            tols = dtype_tols(dtype)
+            mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype)
+            torch.testing.assert_close(y_test, x_ref * mask, **tols)
+            torch.testing.assert_close(dx_test, dy_ref * mask, **tols)
         else:
-            torch.testing.assert_close(y, x_ref, rtol=0, atol=0)
-            torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0)
+            torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0)
+            torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0)
 
         # Hypothesis testing for number of zeros
         # Note: A Bernoulli random variable with probability p has
@@ -1797,9 +1819,11 @@ class TestBasicOps:
         # p-value is less than 1% and we assume that the dropout
         # distribution is incorrect.
         if is_training:
-            prob_observed = 1 - torch.count_nonzero(y).item() / y.numel()
-            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel())
-            assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval"
+            prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel()
+            z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel())
+            assert (
+                abs(z_score) < 2.5758
+            ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})"
 
 
 class TestFusedOps:
...
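
The hypothesis test above models the number of dropped entries as a sum of independent Bernoulli trials and rejects when the z-score of the observed drop fraction falls outside the 99% confidence interval (|z| < 2.5758). A quick worked check with assumed numbers (for illustration only):

import math

prob = 0.5
numel = 128 * 128          # assumed tensor size, matching the new (128, 128) case
prob_observed = 0.51       # suppose 51% of the outputs came out as zero

std = math.sqrt(prob * (1 - prob) / numel)   # ~0.0039: std. dev. of the observed fraction
z_score = (prob_observed - prob) / std       # ~2.56, just inside the 99% interval
assert abs(z_score) < 2.5758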
@@ -69,6 +69,7 @@ list(APPEND transformer_engine_SOURCES
     transpose/quantize_transpose_vector_blockwise.cu
     transpose/swap_first_dims.cu
     activation/gelu.cu
+    dropout/dropout.cu
     fused_attn/flash_attn.cu
     fused_attn/context_parallel.cu
     fused_attn/kv_cache.cu
...
@@ -294,6 +294,38 @@ def _load_nvrtc():
     return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
 
 
+@functools.lru_cache(maxsize=None)
+def _load_curand():
+    """Load cuRAND shared library."""
+    # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
+    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
+    libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
+    libs = list(filter(lambda x: not ("stub" in x), libs))
+    libs.sort(reverse=True, key=os.path.basename)
+    if libs:
+        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+
+    # Attempt to locate cuRAND in Python dist-packages
+    found, handle = _load_nvidia_cuda_library("curand")
+    if found:
+        return handle
+
+    # Attempt to locate cuRAND via ldconfig
+    libs = subprocess.check_output(
+        f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True
+    )
+    libs = libs.decode("utf-8").split("\n")
+    sos = []
+    for lib in libs:
+        if "libcurand" in lib and "=>" in lib:
+            sos.append(lib.split(">")[1].strip())
+    if sos:
+        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
+
+    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
+    return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+
+
 @functools.lru_cache(maxsize=None)
 def _load_core_library():
     """Load shared library with Transformer Engine C extensions"""
@@ -303,6 +335,7 @@ def _load_core_library():
 if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
     _CUDNN_LIB_CTYPES = _load_cudnn()
     _NVRTC_LIB_CTYPES = _load_nvrtc()
+    _CURAND_LIB_CTYPES = _load_curand()
     _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
     _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
     _TE_LIB_CTYPES = _load_core_library()
...
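
The new _load_curand helper follows the same search order as the existing cuDNN and NVRTC loaders: the CUDA toolkit under CUDA_HOME/CUDA_PATH (or /usr/local/cuda), then the pip-installed nvidia-* packages, then ldconfig, and finally a bare dlopen that relies on LD_LIBRARY_PATH. To see what a given machine would resolve, a quick interactive check (a sketch assuming Linux-style .so names; not part of the PR):

import ctypes.util
import glob
import os

cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
print(glob.glob(f"{cuda_home}/**/libcurand.so*", recursive=True))  # toolkit copies, if any
print(ctypes.util.find_library("curand"))                          # what the system loader can see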
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <curand.h>
#include <curand_kernel.h>
#include <curand_philox4x32_x.h>
#include <cmath>
#include "../common.h"
#include "../utils.cuh"
#include "transformer_engine/dropout.h"
namespace transformer_engine {
namespace {
// RNG kernels process chunks of 16 entries
constexpr size_t rng_chunk_size = 16;
// CUDA block size
constexpr size_t block_size = 128;
// Vector class to help with vectorized memory accesses
template <typename T, size_t kSize>
union Vector {
using StorageType = typename BytesToType<sizeof(T) * kSize>::Type;
StorageType storage;
T entries[kSize];
};
/* Byte-wise less-than comparison
*
* Results are stored in each byte's most-significant bit (MSB). All
* other bits are zero.
*/
__device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) {
// Compare low bits by masking MSBs and subtracting. The resulting
// MSBs are 0 if the low bits of a are less than the low bits of b.
uint32_t result = (a | 0x80808080) - (b & 0x7F7F7F7F);
// Bitwise logical op to get answer in MSBs
// Equivalent logic: result = (a == b) ? !result : b
asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result));
// Mask out everything except MSBs and return
result &= 0x80808080;
return result;
}
/* Generate dropout mask with 16 bits.
*
* 1 corresponds to keep and 0 to drop.
*
* Consumes 4 values from cuRAND Philox generator.
*/
__device__ __forceinline__ uint16_t make_16bit_mask(uint64_t chunk_idx, uint64_t rng_seed,
uint64_t rng_offset,
uint32_t bytewise_drop_prob) {
// Generate random bits
curandStatePhilox4_32_10_t state;
curand_init(rng_seed, chunk_idx, rng_offset, &state);
const uint4 rand_bits = curand4(&state);
// Compute mask
// Note: bytewise_less_than fills MSBs (bits 7, 15, 23, 31). By
// shifting 2 bits after every call, every other bit will be filled.
uint32_t result = bytewise_less_than(rand_bits.x, bytewise_drop_prob);
result = (result >> 2) | bytewise_less_than(rand_bits.y, bytewise_drop_prob);
result = (result >> 2) | bytewise_less_than(rand_bits.z, bytewise_drop_prob);
result = (result >> 2) | bytewise_less_than(rand_bits.w, bytewise_drop_prob);
// Consolidate mask in lowest 16 bits
result |= result >> 17;
// Flip bits so 0 corresponds to drop
result = ~result;
return result;
}
// Dropout forward with FP16/BF16 input and output.
template <typename T>
__global__ void __launch_bounds__(block_size)
dropout_kernel_fwd_f16(const T *__restrict__ input_ptr, T *__restrict__ output_ptr,
uint8_t *__restrict__ mask_ptr,
const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks,
uint32_t bytewise_drop_prob, float scale) {
static_assert(sizeof(T) == 2);
// Each thread processes a chunk of 16 entries
const size_t gid = threadIdx.x + blockIdx.x * block_size;
const size_t nthreads = gridDim.x * block_size;
for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
// Generate dropout mask
auto local_mask =
make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob);
reinterpret_cast<uint16_t *>(mask_ptr)[chunk_idx] = local_mask;
// Read input data
using VectorType = Vector<T, rng_chunk_size>;
VectorType local_data;
local_data = reinterpret_cast<const VectorType *>(input_ptr)[chunk_idx];
// Apply dropout based on mask
#pragma unroll
for (size_t i = 0; i < rng_chunk_size; i++) {
float val = static_cast<float>(local_data.entries[i]);
if ((local_mask & 0x1) == 0) {
val = 0;
}
val *= scale;
local_data.entries[i] = static_cast<T>(val);
local_mask >>= 1;
}
// Write output data
reinterpret_cast<VectorType *>(output_ptr)[chunk_idx] = local_data;
}
}
// Dropout forward with FP8 input and FP16/BF16 output.
template <typename InputType, typename OutputType>
__global__ void __launch_bounds__(block_size)
dropout_kernel_fwd_fp8(const InputType *__restrict__ input_ptr,
const float *__restrict__ input_scale_inv_ptr,
OutputType *__restrict__ output_ptr, uint8_t *__restrict__ mask_ptr,
const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks,
uint32_t bytewise_drop_prob, float scale) {
static_assert(sizeof(InputType) == 1);
static_assert(sizeof(OutputType) == 2);
const float input_scale_inv = *input_scale_inv_ptr;
// Each thread processes a chunk of 16 entries
const size_t gid = threadIdx.x + blockIdx.x * block_size;
const size_t nthreads = gridDim.x * block_size;
for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
// Generate dropout mask
auto local_mask =
make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob);
reinterpret_cast<uint16_t *>(mask_ptr)[chunk_idx] = local_mask;
// Read input data
using InputVectorType = Vector<InputType, rng_chunk_size>;
InputVectorType local_input;
local_input = reinterpret_cast<const InputVectorType *>(input_ptr)[chunk_idx];
// Apply dropout based on mask
using OutputVectorType = Vector<OutputType, rng_chunk_size>;
OutputVectorType local_output;
#pragma unroll
for (size_t i = 0; i < rng_chunk_size; i++) {
float val = static_cast<float>(local_input.entries[i]);
val *= input_scale_inv;
if ((local_mask & 0x1) == 0) {
val = 0;
}
val *= scale;
local_output.entries[i] = static_cast<OutputType>(val);
local_mask >>= 1;
}
// Write output data
reinterpret_cast<OutputVectorType *>(output_ptr)[chunk_idx] = local_output;
}
}
// Apply dropout mask and scale.
template <typename T>
__global__ void __launch_bounds__(block_size)
apply_dropout_mask(const T *__restrict__ input_ptr, const uint8_t *__restrict__ mask_ptr,
T *__restrict__ output_ptr, size_t num_chunks, float scale) {
// Each thread processes a chunk of 8 entries.
const size_t gid = threadIdx.x + blockIdx.x * block_size;
const size_t nthreads = gridDim.x * block_size;
constexpr size_t chunk_size = 8;
for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) {
// Read dropout mask
uint8_t local_mask = mask_ptr[chunk_idx];
// Read input data
using VectorType = Vector<T, chunk_size>;
VectorType local_data;
local_data = reinterpret_cast<const VectorType *>(input_ptr)[chunk_idx];
// Apply dropout based on mask
#pragma unroll
for (size_t i = 0; i < chunk_size; i++) {
float val = static_cast<float>(local_data.entries[i]);
if ((local_mask & 0x1) == 0) {
val = 0;
}
val *= scale;
local_data.entries[i] = static_cast<T>(val);
local_mask >>= 1;
}
// Write output data
reinterpret_cast<VectorType *>(output_ptr)[chunk_idx] = local_data;
}
}
} // namespace
void dropout_fwd(const Tensor &input, Tensor &output, Tensor &mask, Tensor &rng_state,
float dropout_probability, cudaStream_t stream) {
// Check tensors
const size_t numel = input.numel();
NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ",
"but scaling mode is ", to_string(input.scaling_mode), ".");
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
to_string(output.scaling_mode), ".");
NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Mask tensor must be plain tensor, ",
"but scaling mode is ", to_string(mask.scaling_mode), ".");
NVTE_CHECK(rng_state.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"RNG state tensor must be INT64 tensor with two entries, ", "but scaling mode is ",
to_string(rng_state.scaling_mode), ".");
NVTE_CHECK(output.dtype() == DType::kFloat16 || output.dtype() == DType::kBFloat16,
"Output tensor must be FP16/BF16 tensor, but dtype is ", to_string(output.dtype()),
".");
NVTE_CHECK(rng_state.dtype() == DType::kInt64,
"RNG state tensor must be INT64 tensor with two entries, but dtype is ",
to_string(rng_state.dtype()), ".");
NVTE_CHECK(numel % 16 == 0,
"Input tensor number of elements must be divisible by 16, but shape is ",
input.shape(), ".");
NVTE_CHECK(numel == output.numel(), "Input tensor (shape=", input.shape(),
") and output tensor (shape=", output.shape(), ") do not match.");
NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel,
" bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), ".");
NVTE_CHECK(rng_state.numel() == 2, "RNG state tensor must be INT64 tensor with two entries, ",
"but shape is ", rng_state.shape(), ".");
NVTE_CHECK(input.data.dptr != nullptr, "Input tensor is missing data.");
NVTE_CHECK(output.data.dptr != nullptr, "Output tensor is missing data.");
NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data.");
NVTE_CHECK(rng_state.data.dptr != nullptr, "RNG state tensor is missing data.");
// Convert dropout probability to scale and 8-bit representation
NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (",
dropout_probability, ").");
const float scale = 1 / (1 - dropout_probability);
uint32_t bytewise_drop_prob = static_cast<uint32_t>(std::floor(dropout_probability * 256));
bytewise_drop_prob |= bytewise_drop_prob << 8;
bytewise_drop_prob |= bytewise_drop_prob << 16;
// CUDA config
const size_t num_chunks = numel / rng_chunk_size;
const size_t num_blocks = DIVUP(num_chunks, block_size);
// Launch kernel depending on input dtype
if (input.dtype() == DType::kFloat16 || input.dtype() == DType::kBFloat16) {
NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()),
") and output tensor (dtype=", to_string(output.dtype()), ") do not match.");
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
input.dtype(), DType,
dropout_kernel_fwd_f16<DType><<<num_blocks, block_size, 0, stream>>>(
reinterpret_cast<const DType *>(input.data.dptr),
reinterpret_cast<DType *>(output.data.dptr),
reinterpret_cast<uint8_t *>(mask.data.dptr),
reinterpret_cast<const uint64_t *>(rng_state.data.dptr), num_chunks, bytewise_drop_prob,
scale););
NVTE_CHECK_CUDA(cudaGetLastError());
} else if (input.dtype() == DType::kFloat8E4M3 || input.dtype() == DType::kFloat8E5M2) {
NVTE_CHECK(input.scale_inv.dptr != nullptr, "Input tensor scale-inverse is not allocated.");
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
input.dtype(), InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
output.dtype(), OutputType,
dropout_kernel_fwd_fp8<InputType, OutputType><<<num_blocks, block_size, 0, stream>>>(
reinterpret_cast<const InputType *>(input.data.dptr),
reinterpret_cast<const float *>(input.scale_inv.dptr),
reinterpret_cast<OutputType *>(output.data.dptr),
reinterpret_cast<uint8_t *>(mask.data.dptr),
reinterpret_cast<const uint64_t *>(rng_state.data.dptr), num_chunks,
bytewise_drop_prob, scale);
););
NVTE_CHECK_CUDA(cudaGetLastError());
} else {
NVTE_ERROR("Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ",
"but dtype is ", to_string(input.dtype()), ".");
}
}
void dropout_bwd(const Tensor &grad_output, const Tensor &mask, Tensor &grad_input,
float dropout_probability, cudaStream_t stream) {
// Check tensors
const size_t numel = grad_output.numel();
NVTE_CHECK(grad_output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Grad output tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
to_string(grad_output.scaling_mode), ".");
NVTE_CHECK(grad_input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Grad input tensor must be FP16/BF16 tensor, ", "but scaling mode is ",
to_string(grad_input.scaling_mode), ".");
NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Mask tensor must be a plain tensor, but scaling mode is ",
to_string(mask.scaling_mode), ".");
NVTE_CHECK(grad_output.dtype() == DType::kFloat16 || grad_output.dtype() == DType::kBFloat16,
"Grad output tensor must be FP16/BF16 tensor, but dtype is ",
to_string(grad_output.dtype()), ".");
NVTE_CHECK(grad_output.dtype() == grad_input.dtype(),
"Grad output tensor (dtype=", to_string(grad_output.dtype()),
") and grad input tensor (dtype=", to_string(grad_input.dtype()), ") do not match.");
NVTE_CHECK(numel % 16 == 0,
"Grad output tensor number of elements must be divisible by 16, but shape is ",
grad_output.shape(), ".");
NVTE_CHECK(numel == grad_input.numel(), "Grad output tensor (shape=", grad_output.shape(),
") and grad input tensor (shape=", grad_input.shape(), ") do not match.");
NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel,
" bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), ".");
NVTE_CHECK(grad_output.data.dptr != nullptr, "Grad output tensor is missing data.");
NVTE_CHECK(grad_input.data.dptr != nullptr, "Grad input tensor is missing data.");
NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data.");
// Convert dropout probability to scale
NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (",
dropout_probability, ").");
const float scale = 1 / (1 - dropout_probability);
// CUDA config
const size_t num_chunks = numel / 8;
const size_t num_blocks = DIVUP(num_chunks, block_size);
// Launch kernel
TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(
grad_output.dtype(), DType,
apply_dropout_mask<DType><<<num_blocks, block_size, 0, stream>>>(
reinterpret_cast<const DType *>(grad_output.data.dptr),
reinterpret_cast<const uint8_t *>(mask.data.dptr),
reinterpret_cast<DType *>(grad_input.data.dptr), num_chunks, scale););
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace transformer_engine
void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask,
NVTETensor rng_state, float dropout_probability, cudaStream_t stream) {
NVTE_API_CALL(nvte_dropout_fwd);
using namespace transformer_engine;
dropout_fwd(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
*convertNVTETensorCheck(mask), *convertNVTETensorCheck(rng_state),
dropout_probability, stream);
}
void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input,
float dropout_probability, cudaStream_t stream) {
NVTE_API_CALL(nvte_dropout_bwd);
using namespace transformer_engine;
dropout_bwd(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(mask),
*convertNVTETensorCheck(grad_input), dropout_probability, stream);
}
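
The bytewise_less_than helper above packs four independent byte comparisons into one 32-bit subtraction plus a lop3 instruction (LUT immediate 0x4D). Its documented equivalent logic, "result = (a == b) ? !result : b" applied per MSB, can be checked exhaustively with a small Python model (an illustrative sketch, not part of the PR):

MASK32 = 0xFFFFFFFF

def bytewise_less_than(a: int, b: int) -> int:
    """Model of the CUDA helper: the MSB of each byte is set iff byte(a) < byte(b)."""
    # Compare the low 7 bits of each byte: MSB of r is 0 iff low(a) < low(b).
    r = ((a | 0x80808080) - (b & 0x7F7F7F7F)) & MASK32
    # lop3 with immediate 0x4D: per bit, result = b if a != b else ~r.
    r = (((a ^ b) & b) | (~(a ^ b) & ~r)) & MASK32
    return r & 0x80808080

# Exhaustive check over all byte pairs, replicated into all four byte lanes.
for x in range(256):
    for y in range(256):
        a, b = x * 0x01010101, y * 0x01010101
        assert bytewise_less_than(a, b) == (0x80808080 if x < y else 0)

The strict less-than against floor(256 * p) is presumably what the "less-equal instead of less-than" fix in the commit history refers to: with a less-equal comparison, a byte exactly equal to the threshold would also be dropped, biasing the drop probability upward by 1/256.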
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file dropout.h
* \brief Functions for dropout.
*/
#ifndef TRANSFORMER_ENGINE_DROPOUT_FP8_H_
#define TRANSFORMER_ENGINE_DROPOUT_FP8_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Dropout forward kernel.
*
* \param[in] input Input tensor.
* \param[out] output Output tensor.
* \param[out] mask Mask tensor. Each bit corresponds to an
* output tensor entry. Ones indicate kept
* entries and zeros indicate dropped entries.
* \param[in] rng_state RNG engine inputs.
* \param[in] dropout_probability Dropout probability.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask,
NVTETensor rng_state, float dropout_probability, cudaStream_t stream);
/*! \brief Dropout backward kernel.
*
* \param[in] grad_output Gradient of output tensor.
* \param[in] mask Mask tensor. Each bit corresponds to an
* output tensor entry. Ones indicate kept
* entries and zeros indicate dropped entries.
* \param[out] grad_input Gradient of input tensor.
* \param[in] dropout_probability Dropout probability.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input,
float dropout_probability, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif
@@ -265,6 +265,17 @@ std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Te
 std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                      py::handle quantizer);
 
+/***************************************************************************************************
+ * Dropout
+ **************************************************************************************************/
+
+std::vector<py::object> dropout_fwd(const py::handle &input, const float dropout_probability,
+                                    std::optional<at::Tensor> out = std::nullopt);
+
+py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
+                       const float dropout_probability,
+                       std::optional<at::Tensor> grad_input = std::nullopt);
+
 /***************************************************************************************************
  * Softmax
  **************************************************************************************************/
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/dropout.h"
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <pybind.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include "../common.h"
#include "../extensions.h"
#include "../pybind.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
namespace pytorch {
std::vector<py::object> dropout_fwd(const py::handle &input, float dropout_probability,
std::optional<at::Tensor> out) {
using namespace transformer_engine::pytorch::detail;
// Input tensor
const TensorWrapper input_nvte = makeTransformerEngineTensor(input, py::none());
// Allocate output tensor if needed
if (!out) {
at::ScalarType dtype = GetATenDType(input_nvte.dtype());
if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2) {
dtype = input.attr("dtype").cast<at::ScalarType>();
}
const auto shape_uint64 = convertShape(input_nvte.shape());
const std::vector<int64_t> shape_int64(shape_uint64.begin(), shape_uint64.end());
const auto opts = at::TensorOptions().dtype(dtype).device(torch::kCUDA);
out = at::empty(shape_int64, opts);
}
TensorWrapper out_nvte = makeTransformerEngineTensor(*out);
// Mask tensor
auto mask_pyt = allocateTorchTensor(input_nvte.numel() / 8, DType::kByte);
auto mask_nvte = makeTransformerEngineTensor(mask_pyt);
// RNG state tensor
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::PhiloxCudaState philox_args;
{
std::lock_guard<std::mutex> lock(gen->mutex_);
constexpr int64_t rng_elts_per_thread = 4;
philox_args = gen->philox_cuda_state(rng_elts_per_thread);
}
auto rng_state_pyt = allocateTorchTensor(2, DType::kInt64);
NVTE_SCOPED_GIL_RELEASE({
nvte_extract_seed_and_offset(
reinterpret_cast<int64_t *>(rng_state_pyt.data_ptr()), philox_args.captured_,
philox_args.seed_.ptr, philox_args.seed_.val, philox_args.offset_.ptr,
philox_args.offset_.val, philox_args.offset_intragraph_, at::cuda::getCurrentCUDAStream());
});
auto rng_state_nvte = makeTransformerEngineTensor(rng_state_pyt);
// Launch kernel
NVTE_SCOPED_GIL_RELEASE({
nvte_dropout_fwd(input_nvte.data(), out_nvte.data(), mask_nvte.data(), rng_state_nvte.data(),
dropout_probability, at::cuda::getCurrentCUDAStream());
});
return {py::cast(std::move(*out)), py::cast(mask_pyt)};
}
py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
const float dropout_probability, std::optional<at::Tensor> grad_input) {
const auto grad_output_nvte = makeTransformerEngineTensor(grad_output);
const auto mask_nvte = makeTransformerEngineTensor(mask);
if (!grad_input) {
grad_input = at::empty_like(grad_output);
}
auto grad_input_nvte = makeTransformerEngineTensor(*grad_input);
NVTE_SCOPED_GIL_RELEASE({
nvte_dropout_bwd(grad_output_nvte.data(), mask_nvte.data(), grad_input_nvte.data(),
dropout_probability, at::cuda::getCurrentCUDAStream());
});
return py::cast(std::move(*grad_input));
}
} // namespace pytorch
} // namespace transformer_engine
@@ -305,6 +305,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"),
         py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd");
 
+  // Dropout
+  m.def("dropout_fwd", transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG",
+        py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt);
+  m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG",
+        py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"),
+        py::arg("grad_input") = std::nullopt);
+
   // Misc
   m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version,
         "Get cublasLt version", py::call_guard<py::gil_scoped_release>());
...
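
With these bindings, the raw extension functions can be exercised directly from Python. A minimal sketch (assuming the extension module is importable as transformer_engine_torch and the input meets the kernel's constraints: FP16/BF16 data whose number of elements is divisible by 16):

import torch
import transformer_engine_torch as tex

x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda")
p = 0.5

# Forward: returns the dropped-and-scaled output plus a bit mask (one byte per 8 entries).
out, mask = tex.dropout_fwd(x, p)

# Backward: applies the saved mask and the 1/(1 - p) scale to an incoming gradient.
dy = torch.randn_like(out)
dx = tex.dropout_bwd(dy, mask, p)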
@@ -8,12 +8,11 @@ from __future__ import annotations
 from typing import Optional
 
 import torch
+import transformer_engine_torch as tex
 
-from transformer_engine.pytorch.ops.op import (
-    BasicOperation,
-    OperationContext,
-)
 from ...tensor import Quantizer
+from ...tensor._internal.float8_tensor_base import Float8TensorBase
+from .._common import maybe_autocast_dtype, maybe_dequantize
+from ..op import BasicOperation, OperationContext
 
 
 class Dropout(BasicOperation):
@@ -27,7 +26,7 @@ class Dropout(BasicOperation):
 
     def __init__(self, p: float) -> None:
         super().__init__()
-        self.dropout_probability = p
+        self.dropout_probability: float = p
 
     def op_forward(
         self,
@@ -37,21 +36,44 @@ class Dropout(BasicOperation):
         next_op_input_quantizer: Optional[Quantizer],
     ) -> torch.Tensor:
 
-        # Compute dropout if training
-        out = input_
-        is_training = self.training
-        mask = None
-        if is_training:
-            keep_prob = 1 - self.dropout_probability
-            mask = torch.empty_like(input_)
-            mask.bernoulli_(keep_prob)
-            mask *= 1 / keep_prob
-            out = out * mask
+        # Output dtype
+        dtype = maybe_autocast_dtype(default_dtype=input_.dtype)
+
+        # Choose implementation
+        impl = None
+        if not self.training:
+            impl = "evaluation"
+        elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16):
+            impl = "fused"
+        else:
+            impl = "unfused"
+
+        # Perform dropout
+        out: torch.Tensor
+        mask: Optional[torch.Tensor] = None
+        if impl == "evaluation":
+            out = input_
+        elif impl == "fused":
+            x = input_
+            if not isinstance(x, Float8TensorBase):
+                x = maybe_dequantize(x, dtype=dtype)
+            out, mask = tex.dropout_fwd(x, self.dropout_probability)
+        elif impl == "unfused":
+            x = maybe_dequantize(input_, dtype=dtype)
+            keep_prob = 1 - self.dropout_probability
+            mask = torch.empty_like(x)
+            mask.bernoulli_(keep_prob)
+            mask *= 1 / keep_prob
+            out = x * mask
+        else:
+            raise ValueError(f"Unsupported forward implementation {impl}")
 
         # Save context for backward
         if ctx.requires_grad:
             ctx.save_for_backward(mask)
-            ctx.is_training = is_training
+            ctx.impl = impl
+            ctx.dropout_probability = self.dropout_probability
+            ctx.dtype = dtype
 
         return out
 
@@ -60,8 +82,21 @@ class Dropout(BasicOperation):
         ctx: OperationContext,
         grad_output: torch.Tensor,
     ) -> tuple[torch.Tensor, tuple[()]]:
+
+        # Saved tensors from forward pass
         (mask,) = ctx.saved_tensors
-        grad_input = grad_output
-        if ctx.is_training:
-            grad_input = grad_input * mask
+
+        # Perform dropout backward pass
+        grad_input: torch.Tensor
+        if ctx.impl == "evaluation":
+            grad_input = grad_output
+        elif ctx.impl == "fused":
+            dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
+            grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability)
+        elif ctx.impl == "unfused":
+            dy = maybe_dequantize(grad_output, dtype=ctx.dtype)
+            grad_input = dy * mask
+        else:
+            raise ValueError(f"Unsupported backward implementation {ctx.impl}")
         return grad_input, ()
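
At the Python level nothing changes for users of the fusible-ops API; Dropout now dispatches to the fused 8-bit-RNG kernel when it can (training mode, element count divisible by 16, FP16/BF16 compute dtype) and otherwise falls back to the original torch-based path. A short usage sketch of that dispatch behavior (assuming the usual te_ops alias for transformer_engine.pytorch.ops; illustrative only):

import torch
import transformer_engine.pytorch.ops as te_ops

op = te_ops.Dropout(0.5)
op.train()

# numel divisible by 16 and BF16 dtype -> fused path (8-bit RNG kernel)
x = torch.randn(128, 128, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = op(x)
y.backward(torch.ones_like(y))

# numel not divisible by 16 -> unfused fallback using torch.bernoulli_
x_odd = torch.randn(101, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y_odd = op(x_odd)

# Evaluation mode is the identity
op.eval()
assert torch.equal(op(x_odd), x_odd)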