Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include "common/utils.cuh"
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
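// Illustrative note (not part of the original header): each DISPATCH_* macro
// switches on an at::ScalarType and aliases the matching C++ type as
// scalar_t_<LEVEL>, so nested dispatches can coexist. A hypothetical call site:
//   DISPATCH_FLOAT_AND_HALF(tensor.scalar_type(), 0, "my_op",
//     my_kernel<scalar_t_0><<<grid, block>>>(tensor.data_ptr<scalar_t_0>(), n);)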
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
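// Note: the *_INOUT_TYPES variants (above and below) dispatch on an
// (input, output) dtype pair; float (and double) inputs may pair with several
// output types, while Half/BFloat16 inputs only pair with same-type outputs.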
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Double: { \
using scalar_t_in = double; \
switch (TYPEOUT) { \
case at::ScalarType::Double: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
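// Tree-reduce partial values in shared memory until 64 remain, then let the
// first warp finish with register shuffles, leaving one result per requested lane.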
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
#ifdef __HIP_PLATFORM_AMD__
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i, THREADS_PER_WARP);
#else
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
}
// Make sure the smem result is visible to all warps, and avoid a potential
// write-before-read race when reduce_block_into_lanes is called back to back.
__syncthreads();
return final;
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
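// Same reduction structure as reduce_block_into_lanes above, but partials are
// combined with fmaxf(fabsf(...)), producing a maximum absolute value.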
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
#else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
@@ -4,10 +4,10 @@
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
#include "common.h"
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
bool rowwise) {
using namespace transformer_engine::pytorch;
@@ -45,80 +45,33 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
if (rowwise) {
input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
} else {
input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
input_shape);
output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
}
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) {
input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
} else {
input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
scale_inv_shape);
}
return swizzled_scale_inv;
}
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);
void* scale_inv_dptr = getDataPtr(scale_inv, 0);
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input),
DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr,
getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
auto output_cu = makeTransformerEngineTensor(
input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING);
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return swizzled_scale_inv;
}
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);
// Return immediately if tensor is empty
if (scale_inv.numel() == 0) {
return swizzled_scale_inv;
}
void* scale_inv_dptr = getDataPtr(scale_inv, 0);
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
auto input_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
auto output_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv),
NVTE_MXFP8_1D_SCALING);
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return swizzled_scale_inv;
}
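// Illustrative note: these helpers reorder MXFP8 scale-inverse tensors into the
// interleaved layout consumed by block-scaled GEMMs; as the header comment below
// states, the swizzled tensor must be kept alive while the GEMM runs.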
@@ -13,8 +13,6 @@
#include "transformer_engine/transformer_engine.h"
bool non_tn_fp8_gemm_supported();
/* Swizzle the scaling factor of the input tensor.
*
* The returned swizzled scaling factor tensor should be kept alive during the GEMM.
......
@@ -19,7 +19,12 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
from . import torch_version
from .utils import (
is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data,
needs_quantized_gemm,
)
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
@@ -267,17 +272,36 @@ def _get_active_autocast_contexts():
"""
autocast_cached = torch.is_autocast_cache_enabled()
if torch_version() >= (2, 4, 0):
gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
gpu_autocast_ctx = torch.amp.autocast(
"cuda",
enabled=gpu_autocast_enabled,
dtype=gpu_autocast_dtype,
cache_enabled=autocast_cached,
)
cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
cpu_autocast_ctx = torch.amp.autocast(
"cpu",
enabled=cpu_autocast_enabled,
dtype=cpu_autocast_dtype,
cache_enabled=autocast_cached,
)
else:
gpu_autocast_enabled = torch.is_autocast_enabled()
gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
gpu_autocast_ctx = torch.cuda.amp.autocast(
gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
)
cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
cpu_autocast_ctx = torch.cpu.amp.autocast(
cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
)
return gpu_autocast_ctx, cpu_autocast_ctx
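# Illustrative sketch (not part of this change): the same version gate lets a
# caller capture the ambient autocast state and re-enter it later, e.g. during
# activation recompute:
#     if torch_version() >= (2, 4, 0):
#         enabled = torch.is_autocast_enabled("cuda")
#         dtype = torch.get_autocast_dtype("cuda")
#     else:
#         enabled = torch.is_autocast_enabled()
#         dtype = torch.get_autocast_gpu_dtype()
#     with torch.amp.autocast("cuda", enabled=enabled, dtype=dtype):
#         ...  # recomputed forward runs under the captured state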
@@ -561,7 +585,9 @@ def has_te_modules(network):
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
from .attention.multi_head_attention import MultiheadAttention
from .transformer import TransformerLayer
te_classes_list = [
@@ -893,8 +919,10 @@ def _all_gather_fp8(
# Note: We cannot directly all-gather the transposed FP8 tensor,
# so temporarily modify quantizer to avoid creating FP8 transpose.
if not isinstance(inp, Float8TensorBase):
if quantizer is None:
raise ValueError("Input tensor is not FP8 and no quantizer was provided")
assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
# we cannot directly gather the transposed fp8 tensor
# so we need to disable columnwise usage for the quantizer
# and then set it back to the original value after quantizing
init_rowwise_usage = quantizer.rowwise_usage
init_columnwise_usage = quantizer.columnwise_usage
quantizer.set_usage(rowwise=True, columnwise=False)
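# Illustrative note: the usage flags saved above are restored after quantizing,
# so the all-gather itself only moves row-wise FP8 data; when a transpose is
# required it is rebuilt locally afterwards (see the needs_transpose check below).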
@@ -938,7 +966,7 @@ def _all_gather_fp8(
# Make sure FP8 transpose is populated if needed
needs_transpose = (
quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
)
if needs_transpose:
if handle is not None:
@@ -1037,11 +1065,11 @@ def _all_gather_mxfp8(
dtype = inp.dtype
elif isinstance(inp, MXFP8TensorBase):
if inp._rowwise_data is not None:
in_shape = inp._rowwise_data.size()
device = inp._rowwise_data.device
dtype = inp._rowwise_data.dtype
elif inp._columnwise_data is not None:
in_shape = inp._columnwise_data.size()
device = inp._columnwise_data.device
dtype = inp._columnwise_data.dtype
else:
@@ -1474,7 +1502,9 @@ def _is_te_module(module):
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
from .attention.multi_head_attention import MultiheadAttention
from .transformer import TransformerLayer
te_classes_list = [
......
@@ -520,8 +520,8 @@ class FP8GlobalStateManager:
return
# Store updated amaxes and scales from phase 1 post forward.
fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()
# Retrieve stashed amaxes and scales from phase 1 pre forward.
buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
......
@@ -536,7 +536,9 @@ def _make_graphed_callables(
# Only Set the FP8 meta for the modules included by forward
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention,
)
if (
isinstance(m, DotProductAttention)
......
@@ -8,6 +8,9 @@ from functools import wraps
from typing import Callable, Optional, Tuple
import torch
from . import torch_version
from .utils import gpu_autocast_ctx
from torch.utils.cpp_extension import IS_HIP_EXTENSION
# pylint: disable=unnecessary-lambda-assignment
@@ -32,13 +35,13 @@ def lazy_compile(func):
jit_fuser = lambda func: func
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = lazy_compile
# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = lazy_compile
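# Illustrative usage (hypothetical function, not from this diff): with the
# gating above, a decorated function is compiled lazily on first call and falls
# back to the plain Python function when NVTE_TORCH_COMPILE=0:
#     @jit_fuser
#     def quick_gelu(x: torch.Tensor) -> torch.Tensor:
#         return x * torch.sigmoid(1.702 * x)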
@@ -51,11 +54,9 @@ def set_jit_fusion_options() -> None:
if not IS_HIP_EXTENSION:
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
if torch_version() >= (2, 2, 0):
pass
elif torch_version() >= (1, 10, 0):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
@@ -124,7 +125,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Disable native AMP for bias_gelu_fused_"""
with gpu_autocast_ctx(enabled=False):
if bias is not None and bias.numel() != 0:
return bias_gelu_fused_(inp, bias)
return gelu_fused_(inp)
@@ -134,7 +135,7 @@ def bgrad_dgelu_fused(
grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
"""Disable native AMP for `bgrad_dgelu_fused_`"""
with gpu_autocast_ctx(enabled=False):
if bias is not None and bias.numel() != 0:
return bgrad_dgelu_fused_(grad_output, inp, bias)
return None, dgelu_fused_(grad_output, inp)
@@ -175,7 +176,7 @@ def bias_dropout_add_fused_train(
) -> torch.Tensor:
"""Disable native AMP and enable grad for BDA"""
with torch.enable_grad():
with gpu_autocast_ctx(enabled=False):
return bias_dropout_add_fused_train_(x, bias, residual, prob)
@@ -191,7 +192,7 @@ def bias_dropout_add_fused_inference(
x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
"""Disable native AMP for BDA"""
with gpu_autocast_ctx(enabled=False):
return bias_dropout_add_fused_inference_(x, bias, residual, prob)
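# For reference, gpu_autocast_ctx imported above is assumed to be a thin
# version-aware wrapper, roughly (sketch):
#     if torch_version() >= (2, 4, 0):
#         gpu_autocast_ctx = functools.partial(torch.amp.autocast, "cuda")
#     else:
#         gpu_autocast_ctx = torch.cuda.amp.autocast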
......
@@ -6,8 +6,6 @@
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass
import queue
import torch
@@ -15,7 +13,6 @@ import torch
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
import warnings
try:
from lightop import rmsnorm_forward, rmsnorm_backward
@@ -24,7 +21,6 @@ except ImportError:
enable_lightop = False
warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)
def _get_normalization_func(normalization: str, forward: bool):
fwd_normalization_funcs = {
"LayerNorm": tex.layernorm_fwd,
@@ -40,39 +36,6 @@ def _get_normalization_func(normalization: str, forward: bool):
return bwd_normalization_funcs[normalization]
def apply_normalization(
inputmat: torch.Tensor,
ln_out: torch.Tensor,
......
@@ -4,6 +4,7 @@
"""Base modules and utilities for TransformerEngine PyTorch API"""
import io
import math
import os
import pickle
import warnings
@@ -35,10 +36,13 @@ from ..distributed import (
_fsdp_gather_tensors,
)
from ..constants import dist_group_type
from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
@@ -451,6 +455,142 @@ def destroy_ub():
layers_atomic_ring_exchange = []
def fill_userbuffers_buffer_for_all_gather(
comm,
local_tensor: torch.Tensor,
quantizer: Optional[Quantizer],
process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
"""Fill local shard of Userbuffers buffer with data for all-gather
Returns the full tensor and the local shard, both using the
Userbuffers buffer as their underlying data. These tensors should
be used carefully (e.g. only immediately before and after a
Userbuffers operation) since the underlying data may be
overwritten by other Userbuffers operations.
May perform blocking communication if needed for the gathered
tensor's metadata, e.g. scaling factors.
"""
# Tensor dimensions
local_shape = local_tensor.size()
if not local_shape:
raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
process_group_size = torch.distributed.get_world_size(process_group)
global_shape = list(local_shape)
global_shape[0] *= process_group_size
# Unquantized data
if quantizer is None:
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor = local_tensor.dequantize()
if comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather unquantized tensor, "
"but Userbuffers is initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor, local_chunk=True)
global_tensor = comm.get_buffer(shape=global_shape)
return global_tensor, local_tensor
# FP8 data
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if not isinstance(local_tensor, Float8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor.dequantize()
quantizer.set_usage(rowwise=True, columnwise=False)
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather FP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor._data, local_chunk=True)
global_tensor_data = comm.get_buffer(shape=global_shape)
global_tensor = Float8TensorBase(
data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# MXFP8 data
if isinstance(quantizer, MXFP8Quantizer):
# Cast to MXFP8 if needed
if not isinstance(local_tensor, MXFP8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor.dequantize()
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather MXFP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
# Check which MXFP8 buffer to communicate
if quantizer.rowwise_usage == quantizer.columnwise_usage:
raise ValueError(
"Userbuffers can only communicate one MXFP8 buffer at a time, "
f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
f"columnwise_usage={quantizer.columnwise_usage}"
)
with_rowwise_data = quantizer.rowwise_usage
# Copy MXFP8 data to local chunk of Userbuffers buffer
local_data = (
local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
)
comm.copy_into_buffer(local_data, local_chunk=True)
# Gather scaling-inverses
if math.prod(local_shape[:-1]) % 128 != 0:
raise ValueError(
"Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
f"but got MXFP8 tensor with shape={tuple(local_shape)}"
)
local_scale_inv = (
local_tensor._rowwise_scale_inv
if with_rowwise_data
else local_tensor._columnwise_scale_inv
)
local_scale_inv_size = list(local_scale_inv.size())
global_scale_inv = torch.empty(
[process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
dtype=local_scale_inv.dtype,
device=local_scale_inv.device,
)
torch.distributed.all_gather_into_tensor(
global_scale_inv,
local_scale_inv,
group=process_group,
)
# Construct MXFP8 tensor with Userbuffers buffer
rowwise_data, rowwise_scale_inv = None, None
columnwise_data, columnwise_scale_inv = None, None
global_data = comm.get_buffer(shape=global_shape)
if with_rowwise_data:
rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
else:
columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
global_tensor = MXFP8TensorBase(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# Unsupported data format
raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
@@ -625,7 +765,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset("scaling_fwd")
reset("scaling_bwd")
def get_extra_state(self) -> Optional[torch.Tensor]:
"""Save before checkpointing."""
# This implementation is working around a few issues:
@@ -659,25 +799,26 @@
# Store FP8 state if needed
state = None
fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
if not fp8_checkpoint:
return None
# Copy tensors to CPU and store
state = {}
state["recipe"] = self.fp8_meta["recipe"]
if state["recipe"].delayed():
state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
# Store other picklable values
extra = {}
for k, v in self.fp8_meta.items():
if k != "buffer_index_and_autocast_key" and isinstance(
v, (bool, int, float, str, tuple, list)
):
extra[k] = v
state["extra_fp8_variables"] = extra
# Serialize state into byte tensor
torch.cuda.synchronize()
@@ -685,7 +826,7 @@
state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
return state_serialized
def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
"""Load previous state."""
if state is None:
return
@@ -734,7 +875,7 @@
"""Get activation data type for AMP."""
# Native AMP (`torch.autocast`) gets highest priority
if torch.is_autocast_enabled():
self.activation_dtype = torch_get_autocast_gpu_dtype()
return
# All checks after this have already been performed once, thus skip
@@ -898,11 +1039,15 @@
# Non-FP8 case: bgrad is fused with wgrad for this case.
if not ctx.fp8 and not ctx.debug:
if gather_grad_output:
if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None: # Perform NCCL all-gather
grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
else: # Initialize Userbuffers all-gather
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ctx.ub_obj_gradout,
grad_output,
None,
ctx.tp_group,
)
return grad_output, None
# FP8 with all-gather: unfused bgrad, fused cast + transpose
@@ -925,8 +1070,12 @@
grad_output = quantizer(grad_output)
# Copy into communication buffer, and replace original gradient with it
grad_output, _ = fill_userbuffers_buffer_for_all_gather(
ctx.ub_obj_gradout,
grad_output,
quantizer,
ctx.tp_group,
)
else:
grad_output, _ = gather_along_first_dim(
grad_output,
@@ -1140,7 +1289,16 @@
raise ValueError(
"tensor and quantizer kwargs must be provided to construct FP8 workspace"
)
if cache_name is not None:
# Ensure the tensor in the cache is an instance of torch.Tensor,
# as it persists beyond a single forward pass.
# Setting internal=True would cause the data to be removed in prepare_for_saving(...).
quantizer_internal = quantizer.internal
quantizer.internal = False
out = quantizer.quantize(tensor, dtype=workspace_dtype)
if cache_name is not None:
quantizer.internal = quantizer_internal
# Update cache
if cache_name is not None:
......@@ -1188,7 +1346,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
return
with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
(wgrad, bgrad), _ = self.wgrad_store.pop()
if not self.fuse_wgrad_accumulation:
unfused_weights = [getattr(self, name) for name in self.weight_names]
weight_tensor = noop_cat(unfused_weights)
......@@ -1197,9 +1355,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
if bias_tensor.grad is None:
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
def _validate_name(self):
"""
......
@@ -4,6 +4,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import warnings
import functools
import torch
@@ -43,7 +44,7 @@ from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
@@ -182,11 +183,11 @@ class _GroupedLinear(torch.autograd.Function):
# TODO: update after #1638 is merged. # pylint: disable=fixme
if weight_requires_grad:
for inputmat in inputmats:
if isinstance(inputmat, QuantizedTensorBase):
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if inp.requires_grad:
for weight in weights_fp8:
if isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True)
tensors_to_save, tensor_objects = prepare_for_saving(
@@ -299,7 +300,7 @@ class _GroupedLinear(torch.autograd.Function):
)
for weight, quantizer in zip(weights, ctx.weight_quantizers):
if quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_usage(
rowwise_usage=quantizer.rowwise_usage,
columnwise_usage=quantizer.columnwise_usage,
@@ -663,7 +664,7 @@ class GroupedLinear(TransformerEngineBaseModule):
produced)
"""
assert not isinstance(
inp, QuantizedTensorBase
), "GroupedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
@@ -675,9 +676,14 @@
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensorBase) else w
for w in weight_tensors
]
input_quantizers, weight_quantizers, output_quantizers = (
......
@@ -94,6 +94,9 @@ class LayerNorm(_LayerNormOp):
)
kwargs["dtype"] = params_dtype
# Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel
# Initialize layer norm operation
super().__init__(
normalized_shape,
@@ -102,12 +105,6 @@
**kwargs,
)
def reset_layer_norm_parameters(self) -> None:
"""Init LN params"""
warnings.warn(
@@ -136,7 +133,7 @@ class LayerNorm(_LayerNormOp):
super().reset_parameters()
# Set flag for sequence parallelism (custom Megatron-LM integration)
if getattr(self, "sequence_parallel", None) is not None:
if self.sequence_parallel is not None:
self.weight.sequence_parallel = self.sequence_parallel
self.bias.sequence_parallel = self.sequence_parallel
......
@@ -9,7 +9,6 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn import init
@@ -18,6 +17,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
get_ub,
TransformerEngineBaseModule,
@@ -53,9 +53,10 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
@@ -144,8 +145,10 @@ class _LayerNormLinear(torch.autograd.Function):
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
inp_requires_grad = inp.requires_grad
assert inp_shape[-1] == in_features, "GEMM not possible"
inp = inp.view((-1, in_features))
inputmat = inp
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
@@ -158,42 +161,43 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")
tp_world_size = get_distributed_world_size(tp_group)
weight_requires_grad = weight.requires_grad
backward_needs_input = is_grad_enabled and weight_requires_grad
with_input_all_gather = parallel_mode == "column" and sequence_parallel
# Check if Userbuffers is supported
if fp8:
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
ub_overlap_ag_fprop = (
ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
)
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
# Configure quantizer for norm output
if fp8:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if with_input_all_gather and isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_blockwise_ln_out_gather = (
fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
)
# Avoid quantized norm kernel if norm output will be returned
# or if a gather of ln_out must be in high precision.
with_quantized_norm = (
fp8
and not return_layernorm_output
@@ -215,16 +219,19 @@
fwd_ln_sm_margin,
zero_centered_gamma,
)
nvtx_range_pop(f"{nvtx_label}.norm")
# Store unquantized layer norm output if we need to return it
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
ln_out_return = ln_out
nvtx_range_pop(f"{nvtx_label}.norm")
# ------------------------------------------------------
# Prepare GEMM input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
ln_out_total = None
if with_input_all_gather:
if return_layernorm_output_gathered:
# Perform all-gather in high precision if gathered
@@ -232,47 +239,53 @@
ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
ln_out_return = ln_out_total
if fp8 or debug:
if not force_hp_blockwise_ln_out_gather:
ln_out = input_quantizer(ln_out)
input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
ln_out = quantizer(ln_out)
quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
ln_out,
quantizer,
tp_group,
)
else: # Perform NCCL all-gather
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=quantizer,
)
else:
if (fp8 or debug) and not with_quantized_norm:
ln_out = input_quantizer(ln_out)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# ------------------------------------------------------
# GEMM input tensor is ready...
# ------------------------------------------------------
# ------------------------------------------------------
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight
quantized_weight = False
if fp8 or debug:
quantized_weight = not isinstance(weight, QuantizedTensorBase)
# Configure quantizer
if weight_quantizer is not None:
weight_quantizer.set_usage(rowwise=True, columnwise=True)
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
tensor=weight,
quantizer=weight_quantizer,
@@ -282,17 +295,21 @@
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
weightmat.update_usage(rowwise_usage=True)
else:
weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP
# ------------------------------------------------------
# Weight tensor is ready for GEMM...
# ------------------------------------------------------
# Cast bias to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
# Calibrate quantizers if needed
if not fp8 and fp8_calibration:
if input_quantizer is not None:
@@ -300,47 +317,80 @@
if weight_quantizer is not None:
weight_quantizer.calibrate(weight)
nvtx_range_push(f"{nvtx_label}.gemm")
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Output buffer for Userbuffers reduce-scatter
reduce_scatter_out = None
if ub_overlap_rs_fprop:
out_shape = list(inp_shape)
out_shape[0] //= tp_world_size
out_shape[-1] = out_features
reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
# ------------------------------------------------------
# Forward GEMM
# Note: y = x * w^T
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat,
ln_out_total,
get_workspace(),
quantization_params=output_quantizer,
out_dtype=activation_dtype,
bias=bias,
use_split_accumulator=use_split_accumulator,
ub=ub_obj,
ub_type=ub_type,
extra_output=reduce_scatter_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")
# ------------------------------------------------------
# Finished forward GEMM...
# ------------------------------------------------------
# Deallocate GEMM input tensor if no longer needed
if not weight.requires_grad and not return_layernorm_output:
clear_tensor_data(ln_out, ln_out_total)
ln_out = ln_out_total = None
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
# ------------------------------------------------------
out = None
if ub_overlap_rs_fprop:
out = reduce_scatter_out
elif parallel_mode == "row" and tp_size > 1:
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
out = gemm_out
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
else:
out = gemm_out
out = out.view(-1, *inp_shape[1:-1], out_features)
# ------------------------------------------------------
# Output tensor is ready to return...
# ------------------------------------------------------
# ------------------------------------------------------
# Cache state for backward pass
# ------------------------------------------------------
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
@@ -351,19 +401,15 @@
# Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
if isinstance(ln_out, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data
# can be allgathered.
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False)
# For force_hp_blockwise_ln_out_gather, we should
# be saving the unquantized ln_out to ctx.
assert not force_hp_blockwise_ln_out_gather
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(weightmat, QuantizedTensorBase):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading:
@@ -406,7 +452,7 @@
)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
ctx.requires_dgrad = inp_requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.quantized_weight = quantized_weight
if fuse_wgrad_accumulation and weight.requires_grad:
@@ -439,7 +485,7 @@
ctx.ub_bulk_wgrad = ub_bulk_wgrad
ctx.ub_bulk_dgrad = ub_bulk_dgrad
ctx.ub_name = ub_name
ctx.requires_dgrad = inp_requires_grad
ctx.normalization = normalization
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
@@ -450,29 +496,16 @@
ctx.wgrad_store = wgrad_store
ctx.debug = debug
# ------------------------------------------------------
# Cached state for backward pass is ready...
# ------------------------------------------------------
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
shape[0] *= tp_size
return out, ln_out_return.view(shape)
return out, ln_out_return.view(inp_shape)
return out
@staticmethod
@@ -549,66 +564,50 @@
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
@@ -549,66 +564,50 @@
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad
# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
ub_obj_wgrad = None
ub_type_dgrad = None
ub_type_wgrad = None
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
rs_out = None
dgrad_bulk = None
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
)
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_type_wgrad = tex.CommOverlapType.RS
ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer)
# --------------------------------------------------
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# --------------------------------------------------
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
quantizer = ctx.grad_output_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag:
# Userbuffers only supports communication for one
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
@@ -624,12 +623,21 @@
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# --------------------------------------------------
# Grad output tensor is ready for computing grad input...
# --------------------------------------------------
# --------------------------------------------------
# Prepare GEMM input tensor
# Note: Input tensor is needed for wgrad GEMM.
# Tensor-parallel communication is overlapped with dgrad
# GEMM.
# --------------------------------------------------
ln_out_total = None
ln_out_total_work = None
if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad:
if ctx.ln_out_needs_gather:
quantizer = None
if ctx.input_quantizer is not None:
if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
@@ -637,70 +645,92 @@ class _LayerNormLinear(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
# async_op is not compatible with high precision gather since
# gather_along_first_dim does not offer callback chaining.
gather_quantizer = None if ctx.force_hp_blockwise_ln_out_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
if ctx.ub_bulk_dgrad:
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_dgrad,
ln_out,
quantizer,
ctx.tp_group,
)
else:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
ln_out_total = ln_out
# Check whether to output wgrad GEMM directly into main grad
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
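# The flag above reduces to a small truth table; a minimal standalone
# sketch (plain Python, names chosen to mirror the logic above):
#
#   def accumulate_into_main_grad(fuse_wgrad_accumulation, is_first_microbatch):
#       if is_first_microbatch is not None:
#           # First microbatch overwrites main_grad, later ones accumulate
#           return fuse_wgrad_accumulation and not is_first_microbatch
#       return fuse_wgrad_accumulation
#
#   assert accumulate_into_main_grad(True, True) is False   # first microbatch
#   assert accumulate_into_main_grad(True, False) is True   # later microbatch
#   assert accumulate_into_main_grad(True, None) is True    # no microbatching hint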
# dgrad GEMM
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
# --------------------------------------------------
# Input tensor is ready for computing grad weight...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad input tensor
# Note: Gradient w.r.t. GEMM input (i.e. norm output).
# --------------------------------------------------
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
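# A minimal standalone sketch of the recipe-override pattern above (mock
# recipe object; only the attribute names from the code above are assumed):
#
#   from types import SimpleNamespace
#   recipe = SimpleNamespace(fp8_gemm_dgrad=SimpleNamespace(use_split_accumulator=True))
#   use_split_accumulator = False  # stand-in for the _2X_ACC_DGRAD default
#   if hasattr(recipe, "fp8_gemm_dgrad"):
#       use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator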
# Update grad input quantizer
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
weight.update_usage(
rowwise_usage=ctx.weight_quantizer.rowwise_usage,
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
)
dgrad, *_ = general_gemm(
elif ctx.ub_bulk_wgrad:
gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weight,
grad_output,
get_workspace(),
layout="NN",
grad=True,
quantization_params=ctx.grad_input_quantizer,
out=dgrad_bulk,
out=gemm_out,
out_dtype=ctx.activation_dtype,
use_split_accumulator=dgrad_gemm_use_split_accumulator,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_dgrad,
ub_type=ub_type_dgrad,
extra_output=rs_out,
extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
# Launch tensor-parallel communication
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
dgrad = None
dgrad_work = None
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
dgrad = reduce_scatter_out
elif ctx.ub_bulk_wgrad:
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
dgrad = gemm_out
if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
dgrad, dgrad_work = reduce_scatter_along_first_dim(
dgrad,
ctx.tp_group,
@@ -709,41 +739,55 @@ class _LayerNormLinear(torch.autograd.Function):
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
else:
dgrad = gemm_out
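# For a column-parallel layer each rank holds a weight shard, so the dgrad
# GEMM above produces a partial sum that must be reduced across the TP group.
# A minimal standalone sketch of the choice between reduce-scatter (sequence
# parallel) and all-reduce, using synchronous torch.distributed calls in
# place of the async ones above (assumes an initialized process group;
# illustrative only):
#
#   import torch.distributed as dist
#   def finalize_dgrad(dgrad, tp_group, sequence_parallel):
#       if sequence_parallel:
#           # Each rank keeps only its shard of the summed result
#           shard_shape = (dgrad.shape[0] // dist.get_world_size(tp_group),
#                          *dgrad.shape[1:])
#           out = dgrad.new_empty(shard_shape)
#           dist.reduce_scatter_tensor(out, dgrad, group=tp_group)
#           return out
#       dist.all_reduce(dgrad, group=tp_group)
#       return dgrad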
# --------------------------------------------------
# Grad input tensor has been computed...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad weight
# --------------------------------------------------
# Compute grad weight tensor
wgrad = None
if ctx.requires_wgrad:
# Synchronize tensor-parallel communication for input tensor
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer)
if ctx.fp8:
# FP8 GEMM on Hopper only supports TN layout so the gathered input must have
# a valid transpose.
if ln_out._data is None:
# The all-gather was executed on column-wise data and the result lands
# in the row-wise data, so we need to fix the interleaving before WGRAD.
ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
else:
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total._create_transpose()
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# Async gather may have been done in BF16,
# so call the quantizer after the gather completes.
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
ln_out_total.update_usage(columnwise_usage=True)
if isinstance(grad_output, QuantizedTensor):
grad_output.update_usage(columnwise_usage=True)
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.input_quantizer(ln_out_total)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.grad_output_quantizer(grad_output)
# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
@@ -752,55 +796,95 @@ class _LayerNormLinear(torch.autograd.Function):
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# Output buffer for overlapping grad input
# Figure out whether to output wgrad GEMM directly into main grad
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Output buffer for overlapping FP8 grad input
# reduce-scatter with wgrad GEMM
reduce_scatter_out = None
if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=inputmat.device
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
)
# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
general_gemm_wgrad = functools.partial(
general_gemm,
out_dtype=(
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
grad=True,
bias=(bias if (grad_bias is None and not ctx.fp8) else None),
out=main_grad if ctx.fuse_wgrad_accumulation else None,
use_split_accumulator=use_split_accumulator,
accumulate=accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.grad_weight_quantizer,
ub=ub_obj_wgrad,
ub_type=ub_type_wgrad,
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
"use_split_accumulator": use_split_accumulator,
"grad": True,
"ub": ub_obj_wgrad,
"ub_type": ub_type_wgrad,
"extra_output": reduce_scatter_out,
"bulk_overlap": ctx.ub_bulk_wgrad,
}
def wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform wgrad GEMM: dw = dy^T * x
May be fused with bgrad computation.
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
return dw, db
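# A minimal standalone sketch of the delay mechanism used below: the
# closure and its inputs are parked and executed later, e.g. to overlap
# the wgrad GEMM with other work (mock store; the real WeightGradStore
# API may differ):
#
#   class MockWgradStore:
#       def __init__(self):
#           self._queue = []
#       def put(self, tensors, gemm_fn):
#           self._queue.append((tensors, gemm_fn))
#       def pop(self):
#           tensors, gemm_fn = self._queue.pop(0)
#           return gemm_fn(*tensors)  # runs the deferred wgrad GEMM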
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
if (
wgrad_gemm_kwargs["ub"] is not None
or wgrad_gemm_kwargs["ub_type"] is not None
or wgrad_gemm_kwargs["extra_output"] is not None
or wgrad_gemm_kwargs["bulk_overlap"]
):
raise NotImplementedError(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
# Call wgrad GEMM now
wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output)
# Update grad bias if needed
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor
# Deallocate input tensor if permitted
if not ctx.return_layernorm_output:
# TODO (pgadzinski) - deallocate transpose only # pylint: disable=fixme
clear_tensor_data(ln_out_total)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = rs_out
dgrad = reduce_scatter_out
else:
dgrad = ub_obj_wgrad.get_buffer(None, local_chunk=True)
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()
# --------------------------------------------------
# Grad weight has been computed...
# --------------------------------------------------
# Don't return grad bias if not needed
if not ctx.use_bias:
@@ -879,7 +963,7 @@ class _LayerNormLinear(torch.autograd.Function):
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensor):
# if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
# _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
@@ -1405,6 +1489,10 @@ class LayerNormLinear(TransformerEngineBaseModule):
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights)
@@ -1511,7 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = False
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
if fp8_output:
@@ -1579,3 +1667,12 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# Parallelism-related configuration
if self.sequence_parallel and self.parallel_mode == "row":
# Customize grad_output_quantizer with amax reduction over the TP group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
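# Amax reduction keeps FP8 scale factors consistent when a tensor is
# sharded across the TP group: each rank contributes its local amax and
# all ranks derive the scale from the global maximum. A minimal standalone
# sketch (torch.distributed; assumes an initialized process group):
#
#   import torch.distributed as dist
#   def global_amax(t, group):
#       amax = t.abs().max()
#       dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
#       return amax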
@@ -8,7 +8,6 @@ import warnings
from typing import Callable, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import functools
import torch
from torch.nn.parameter import Parameter
@@ -20,6 +19,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
fill_userbuffers_buffer_for_all_gather,
get_workspace,
_ub_communicators,
get_ub,
@@ -43,7 +43,6 @@ from ..utils import (
assert_dim_for_fp8_exec,
clear_tensor_data,
requires_grad,
non_tn_fp8_gemm_supported,
needs_quantized_gemm,
)
from ..distributed import (
@@ -67,10 +66,11 @@ from ..tensor.float8_tensor import (
)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, _fix_gathered_fp8_transpose, WeightGradStore
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
@@ -201,24 +201,16 @@ class _LayerNormMLP(torch.autograd.Function):
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
in_features, inp_shape = ln_weight.numel(), inp.shape
# Make sure input dimensions are compatible
in_features, inp_shape = ln_weight.numel(), inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
if any([ub_overlap_ag, ub_overlap_rs]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
activation_func = _act_func(
activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
)[0]
device = inp.device
# Cast for native AMP
inputmat = cast_if_needed(inputmat, activation_dtype)
@@ -226,6 +218,38 @@ class _LayerNormMLP(torch.autograd.Function):
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
tp_world_size = get_distributed_world_size(tp_group)
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
device = inp.device
# Configure Userbuffers communication (comm+GEMM overlap)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Configure quantizer for norm output
if fp8:
if fc1_input_quantizer is None:
raise ValueError("Missing quantizer for FC1 input tensor")
fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
if sequence_parallel and isinstance(
fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data
fc1_input_quantizer.set_usage(columnwise=False)
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
)
# For FP8 DelayedScaling: layernorm output = FP8;
# only the output of the linear is returned.
# For return_layernorm_output: layernorm output = high precision, then cast to FP8
@@ -241,29 +265,6 @@ class _LayerNormMLP(torch.autograd.Function):
# Kernels not available for norm fusion.
with_quantized_norm = False
tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
ub_overlap_rs = ub_overlap_rs and is_grad_enabled
backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
# TODO(kwyss): Support FP8 allgather of Float8BlockQuantizer recipe
force_hp_fc1_input_gather = (
fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
# Configure quantizer for norm output
if fp8:
if fc1_input_quantizer is None:
raise ValueError("Missing quantizer for FC1 input tensor")
columnwise_usage = backwards_needs_fc1_input
if (
columnwise_usage
and sequence_parallel
and not isinstance(fc1_input_quantizer, MXFP8Quantizer)
):
columnwise_usage = False
fc1_input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# Apply normalization
ln_out, mu, rsigma = apply_normalization(
inputmat,
@@ -297,39 +298,43 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
ln_out_total = fc1_input_quantizer(ln_out_total)
else:
quantizer = None
if fp8 or debug:
quantizer = fc1_input_quantizer
if not with_quantized_norm and not force_hp_fc1_input_gather:
ln_out = fc1_input_quantizer(ln_out)
fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
if ub_overlap_ag:
# Copy into Userbuffers buffer
ub_obj_lnout = get_ub("fc1_fprop")
ub_obj_lnout.get_buffer(fc1_input_quantizer, local_chunk=True).copy_(ln_out)
ln_out_total = ub_obj_lnout.get_buffer(fc1_input_quantizer)
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_lnout,
ln_out,
quantizer,
tp_group,
)
else:
# All-gather with NCCL
ln_out_total, _ = gather_along_first_dim(
ln_out,
tp_group,
quantizer=(fc1_input_quantizer if fp8 or debug else None),
quantizer=quantizer,
)
else:
# NOTE: force_hp_fc1_input_gather is redundant in this else branch, but
# checked here for clarity. We should not quantize ln_out if the backward
# pass needs to gather it in high precision.
if (fp8 or debug) and not with_quantized_norm and not force_hp_fc1_input_gather:
if (fp8 or debug) and not with_quantized_norm:
ln_out = fc1_input_quantizer(ln_out)
ln_out_total = ln_out
# Cast weights to expected dtype
fc1_weight_final = fc1_weight
fc2_weight_final = fc2_weight
if fp8 or debug:
# If weights are not quantized, we call get_weight_workspace,
# which handles weight caching etc.
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc1_weight_final = module.get_weight_workspace(
tensor=fc1_weight,
quantizer=fc1_weight_quantizer,
@@ -339,7 +344,6 @@ class _LayerNormMLP(torch.autograd.Function):
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
fc2_weight_final = module.get_weight_workspace(
tensor=fc2_weight,
quantizer=fc2_weight_quantizer,
@@ -349,6 +353,8 @@ class _LayerNormMLP(torch.autograd.Function):
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
fc1_weight_final.update_usage(rowwise_usage=True)
fc2_weight_final.update_usage(rowwise_usage=True)
else:
fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
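# A minimal standalone sketch of the is_first_microbatch contract assumed
# by the weight-workspace caching above (hypothetical helper; the real
# logic lives in TransformerEngineBaseModule.get_weight_workspace):
#
#   def need_weight_requantization(is_first_microbatch):
#       # None  -> no hint from caller, requantize every step
#       # True  -> first microbatch, requantize and cache
#       # False -> later microbatch, reuse cached quantized weights
#       return is_first_microbatch is None or is_first_microbatch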
@@ -356,6 +362,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Cast biases to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype = torch.bfloat16
if fc1_bias is not None:
fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
@@ -369,7 +376,9 @@ class _LayerNormMLP(torch.autograd.Function):
if fc1_weight_quantizer is not None:
fc1_weight_quantizer.calibrate(fc1_weight)
# ------------------------------------------------------
# FC1 GEMM
# ------------------------------------------------------
# There are 2 fusions possible:
# - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
@@ -401,11 +410,16 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_bias if not bias_gelu_fusion else None
), # otherwise bias is added later (fused with gelu)
gelu=gemm_gelu_fusion,
accumulate=_2X_ACC_FPROP,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_lnout,
ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
)
# ------------------------------------------------------
# Finished FC1 GEMM...
# ------------------------------------------------------
# Deallocate FC1 GEMM input tensor if no longer needed
if not is_grad_enabled and (ln_out_total is not ln_out_return):
clear_tensor_data(ln_out_total)
@@ -439,45 +453,66 @@ class _LayerNormMLP(torch.autograd.Function):
fc2_input_quantizer.calibrate(act_out)
fc2_weight_quantizer.calibrate(fc2_weight)
# Configure Userbuffers reduce-scatter if needed
ub_obj_fc2out = None
rs_out = None
fc2_out = None
reduce_scatter_out = None
if ub_overlap_rs:
ub_obj_fc2out = get_ub("fc2_fprop")
dim_size = list(act_out.size())
dim_size[0] = dim_size[0] // tp_world_size
dim_size[1] = fc2_weight.size(0)
rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
else:
dim_size = list(act_out.size())
dim_size[1] = fc2_weight.size(0)
fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
dim_size[0] //= tp_world_size
dim_size[-1] = fc2_weight.size(0)
reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
# ------------------------------------------------------
# FC2 GEMM
_ = general_gemm(
# ------------------------------------------------------
gemm_out, *_, reduce_scatter_out = general_gemm(
fc2_weight_final,
act_out,
get_workspace(),
out_dtype=activation_dtype,
bias=fc2_bias,
quantization_params=fc2_output_quantizer,
out=fc2_out,
use_split_accumulator=_2X_ACC_FPROP,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_fc2out,
ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
extra_output=rs_out,
extra_output=reduce_scatter_out,
)
# ------------------------------------------------------
# Finished FC2 GEMM...
# ------------------------------------------------------
# Deallocate tensors if no longer needed
if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
# Prepare output tensor
# Note: Perform tensor-parallel communication if needed
fc2_out = None
if ub_overlap_rs:
fc2_out = reduce_scatter_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
elif set_parallel_mode and tensor_parallel:
if symmetric_ar_type is not None:
fc2_out, _ = symmetric_all_reduce(
gemm_out, tp_group, all_reduce_type=symmetric_ar_type
)
else:
fc2_out, _ = allreduce(gemm_out, tp_group)
else:
fc2_out = gemm_out
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
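# Shape bookkeeping sketch for the view above: the GEMM operates on a
# flattened [tokens, features] matrix and the original leading dimensions
# are restored afterwards (illustrative numbers):
#
#   import torch
#   inp_shape = (4, 3, 8)          # e.g. [seq, batch, hidden]
#   out = torch.randn(4 * 3, 16)   # flattened FC2 output
#   out = out.view(-1, *inp_shape[1:-1], out.shape[-1])
#   assert out.shape == (4, 3, 16)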
# Weight with column-wise usage is needed for dgrad GEMM.
# Cache state for backward pass
if is_grad_enabled:
if isinstance(fc1_weight_final, QuantizedTensor):
# Weight with column-wise usage is needed for dgrad GEMM.
if isinstance(fc1_weight_final, QuantizedTensorBase):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensor):
if isinstance(fc2_weight_final, QuantizedTensorBase):
fc2_weight_final.update_usage(columnwise_usage=True)
if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
else:
if cpu_offloading:
mark_activation_offload(
inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
@@ -504,8 +539,6 @@ class _LayerNormMLP(torch.autograd.Function):
if not return_layernorm_output:
clear_tensor_data(ln_out)
ln_out = None
elif force_hp_fc1_input_gather:
assert not isinstance(ln_out, QuantizedTensor)
if not fc2_weight.requires_grad:
clear_tensor_data(act_out)
act_out = None
@@ -592,28 +625,12 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.wgrad_store = wgrad_store
# Row Parallel Linear
if ub_overlap_rs:
fc2_out = rs_out
elif set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
elif set_parallel_mode and tensor_parallel:
if symmetric_ar_type is not None:
fc2_out, _ = symmetric_all_reduce(
fc2_out, tp_group, all_reduce_type=symmetric_ar_type
)
else:
fc2_out, _ = allreduce(fc2_out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])
if return_layernorm_output:
if return_layernorm_output_gathered:
shape = list(inp_shape)
shape[0] *= tp_size
return fc2_out, ln_out_return.view(shape)
return fc2_out, ln_out_return.view_as(inp)
return fc2_out, ln_out_return.view(inp_shape)
return fc2_out
@staticmethod
@@ -622,24 +639,6 @@ class _LayerNormMLP(torch.autograd.Function):
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
if (
ctx.fp8
and any(
[
ctx.ub_overlap_ag,
ctx.ub_overlap_rs_dgrad,
ctx.ub_bulk_dgrad,
ctx.ub_bulk_wgrad,
]
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
( # pylint: disable=unbalanced-tuple-unpacking
inputmat,
@@ -699,6 +698,16 @@ class _LayerNormMLP(torch.autograd.Function):
# fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
# )
# Choose whether to use GEMM kernel with split accumulator
dgrad_use_split_accumulator = _2X_ACC_DGRAD
wgrad_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
if hasattr(recipe, "fp8_gemm_wgrad"):
wgrad_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required
ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
@@ -707,20 +716,13 @@ class _LayerNormMLP(torch.autograd.Function):
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.fc2_grad_output_quantizer is not None:
rowwise_usage = True
columnwise_usage = True
if ctx.ub_overlap_ag and isinstance(
ctx.fc2_grad_output_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
# If data is in FP8 and communication is handled
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.fc2_grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
quantizer = ctx.fc2_grad_output_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag:
# Userbuffers only supports communication for one
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
# Prepare FC2 grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
@@ -738,14 +740,10 @@ class _LayerNormMLP(torch.autograd.Function):
# Launch tensor-parallel communication for FC1 GEMM input
ln_out_total = None
ln_out_total_work = None
if (
ctx.fc1_weight_requires_grad
and ctx.tensor_parallel
and ctx.sequence_parallel
and not ctx.ub_bulk_dgrad
):
ub_obj_fc1_dgrad = None
if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
quantizer = None
if ctx.fp8 or ctx.debug:
if (ctx.fp8 or ctx.debug) and not ctx.force_hp_fc1_input_gather:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
@@ -753,13 +751,21 @@ class _LayerNormMLP(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
if ctx.ub_bulk_dgrad:
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc1_dgrad,
ln_out,
quantizer,
ctx.tp_group,
)
else:
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
else:
ln_out_total = ln_out
@@ -770,6 +776,11 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# --------------------------------------------------
# FC2 DGRAD
# --------------------------------------------------
# There are 6 possible fusion paths
# 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
# 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
@@ -784,12 +795,15 @@ class _LayerNormMLP(torch.autograd.Function):
and (not ctx.debug)
)
# FC2 DGRAD; Unconditional
if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
ctx.fc2_weight.update_usage(
rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
)
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.fc2_weight_quantizer is not None and isinstance(
ctx.fc2_weight, QuantizedTensorBase
):
ctx.fc2_weight.update_usage(columnwise_usage=True)
# Perform GEMM
gemm_output, *_ = general_gemm(
fc2_weight,
grad_output,
@@ -804,52 +818,107 @@ class _LayerNormMLP(torch.autograd.Function):
out_dtype=ctx.activation_dtype,
gelu=fc2_dgrad_gemm_gelu_fusion,
gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None,
use_split_accumulator=_2X_ACC_DGRAD,
use_split_accumulator=dgrad_use_split_accumulator,
ub=ub_obj_fc2_dgrad,
ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None,
)
# Prepare input grad tensor
dact = None
fc2_dgrad = None
if fc2_dgrad_gemm_gelu_fusion:
dact = gemm_output
fc2_dgrad = None
else:
fc2_dgrad = gemm_output
# --------------------------------------------------
# Finished FC2 DGRAD...
# --------------------------------------------------
# --------------------------------------------------
# FC2 WGRAD
# --------------------------------------------------
fc2_wgrad = None
if ctx.fc2_weight_requires_grad:
if isinstance(act_out, QuantizedTensor):
act_out.update_usage(rowwise_usage=True, columnwise_usage=True)
if isinstance(grad_output, QuantizedTensor):
grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(act_out, QuantizedTensorBase):
act_out.update_usage(columnwise_usage=True)
else:
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.fc2_grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.fc2_grad_output_quantizer(grad_output)
# Whether to set grad arg in general_gemm
grad_arg = True
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
grad_arg = False
general_gemm_fc2_wgrad = functools.partial(
general_gemm,
out_dtype=(
# Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
quantization_params=ctx.fc2_grad_weight_quantizer, # wgrad in high precision
layout="NT",
grad=grad_arg,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
"quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
"use_split_accumulator": wgrad_use_split_accumulator,
"grad": grad_arg,
}
def fc2_wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform FC2 WGRAD GEMM
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
dw, db, *_ = general_gemm(x, dy, **fc2_wgrad_gemm_kwargs)
return dw, db
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
fc2_wgrad = None
ctx.wgrad_store.put([act_out, grad_output], fc2_wgrad_gemm)
else:
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
act_out,
grad_output,
)
# Call wgrad GEMM now
fc2_wgrad, fc2_bias_grad_ = fc2_wgrad_gemm(act_out, grad_output)
# Update grad bias if needed
if fc2_bias_grad is None:
if (
ctx.fp8
@@ -858,12 +927,17 @@ class _LayerNormMLP(torch.autograd.Function):
):
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
# Deallocate input tensor if permitted
if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
clear_tensor_data(act_out)
# --------------------------------------------------
# Finished FC2 WGRAD...
# --------------------------------------------------
# bias computation
fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False
@@ -927,63 +1001,69 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad = None
ub_obj_fc1_wgrad = None
ub_type_fc1_dgrad = None
ub_type_fc1_wgrad = None
fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
fc1_dgrad_rs_out = None
fc1_dgrad_bulk = None
if ctx.ub_overlap_rs_dgrad:
# Overlap DGRAD+RS
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_type_fc1_dgrad = tex.CommOverlapType.RS
fc1_dgrad_rs_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
else:
if ctx.ub_bulk_dgrad:
# Overlap ln_out all-gather with DGRAD compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_type_fc1_dgrad = tex.CommOverlapType.AG
ub_obj_fc1_dgrad.copy_into_buffer(
ln_out, ctx.fc1_input_quantizer, local_chunk=True
)
if ctx.ub_bulk_wgrad:
# Overlap FC1 DGRAD reduce-scatter with WGRAD compute
ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
ub_type_fc1_wgrad = tex.CommOverlapType.RS
# FC1 DGRAD: Unconditional
# --------------------------------------------------
# FC1 DGRAD
# --------------------------------------------------
# Make sure required data is available
if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight_quantizer, QuantizedTensor
ctx.fc1_weight_quantizer, QuantizedTensorBase
):
ctx.fc1_weight.update_usage(
rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
ctx.fc1_weight.update_usage(columnwise_usage=True)
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
if ctx.ub_bulk_wgrad:
gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False)
# dgrad GEMM
gemm_out, *_, reduce_scatter_out = general_gemm(
fc1_weight,
dact,
get_workspace(),
out=fc1_dgrad_bulk,
out=gemm_out,
out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer,
layout="NN",
grad=True,
use_split_accumulator=dgrad_use_split_accumulator,
ub=ub_obj_fc1_dgrad,
ub_type=ub_type_fc1_dgrad,
extra_output=fc1_dgrad_rs_out,
extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
# Overlap dgrad-RS/AR with wgrad
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
fc1_dgrad = None
fc1_dgrad_work = None
if ctx.ub_overlap_rs_dgrad:
fc1_dgrad = fc1_dgrad_rs_out
fc1_dgrad = reduce_scatter_out
elif ctx.ub_bulk_wgrad:
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True)
elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad:
fc1_dgrad = gemm_out
if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
@@ -994,90 +1074,125 @@ class _LayerNormMLP(torch.autograd.Function):
)
elif ctx.tensor_parallel:
fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
else:
fc1_dgrad = gemm_out
# --------------------------------------------------
# Finished FC1 DGRAD...
# --------------------------------------------------
# --------------------------------------------------
# FC1 WGRAD
# --------------------------------------------------
fc1_wgrad = None
if ctx.fc1_weight_requires_grad:
# Synchronize tensor-parallel communication for FC1 GEMM input tensor
if ctx.ub_bulk_dgrad:
ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer)
if ctx.fp8:
if ln_out._data is None:
# The all-gather was executed on column-wise data and the result lands
# in the row-wise data, so we need to fix the interleaving before WGRAD.
ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
elif not non_tn_fp8_gemm_supported():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total._create_transpose()
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ln_out_total_work is not None:
ln_out_total_work.wait()
ln_out_total_work = None
if ctx.fc1_input_quantizer is not None and not isinstance(
ln_out_total, QuantizedTensor
):
# The async gather may have run in BF16, so the
# quantizer must be applied after the gather completes.
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
# Make sure GEMM inputs have required data
if isinstance(ln_out_total, QuantizedTensor):
ln_out_total.update_usage(columnwise_usage=True)
if isinstance(dact, QuantizedTensor):
dact.update_usage(columnwise_usage=True)
if ctx.fp8 or ctx.debug:
if isinstance(ln_out_total, QuantizedTensorBase):
ln_out_total.update_usage(columnwise_usage=True)
else:
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(dact, QuantizedTensorBase):
dact.update_usage(columnwise_usage=True)
else:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
dact = ctx.fc1_grad_output_quantizer(dact)
# Output buffer for overlapping grad input
# reduce-scatter with wgrad GEMM
reduce_scatter_out = None
if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
fc1_dgrad_rs_out = torch.empty(
reduce_scatter_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
# wgrad GEMM
general_gemm_fc1_wgrad = functools.partial(
general_gemm,
out_dtype=(
# Arguments to include in wgrad GEMM closure
fc1_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
layout="NT",
quantization_params=ctx.fc1_grad_weight_quantizer,
grad=fuse_gemm_and_bias_fc1_wgrad,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
accumulate=accumulate_wgrad_into_param_main_grad,
out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
ub=ub_obj_fc1_wgrad,
ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
extra_output=fc1_dgrad_rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
"quantization_params": ctx.fc1_grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
"use_split_accumulator": wgrad_use_split_accumulator,
"grad": fuse_gemm_and_bias_fc1_wgrad,
"ub": ub_obj_fc1_wgrad,
"ub_type": ub_type_fc1_wgrad,
"extra_output": reduce_scatter_out,
"bulk_overlap": ctx.ub_bulk_wgrad,
}
def fc1_wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
_is_delayed: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform FC1 WGRAD GEMM
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
dw, db, *_ = general_gemm(x, dy, **fc1_wgrad_gemm_kwargs)
return dw, db
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
if (
fc1_wgrad_gemm_kwargs["ub"] is not None
or fc1_wgrad_gemm_kwargs["ub_type"] is not None
or fc1_wgrad_gemm_kwargs["extra_output"] is not None
or fc1_wgrad_gemm_kwargs["bulk_overlap"]
):
raise NotImplementedError(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
fc1_wgrad = None
if fuse_gemm_and_bias_fc1_wgrad:
fc1_bias_grad = None
else:
fc1_wgrad_outputs = general_gemm_fc1_wgrad(
ln_out_total,
dact,
)
clear_tensor_data(ln_out_total, dact)
# Call wgrad GEMM now
fc1_wgrad_outputs = fc1_wgrad_gemm(ln_out_total, dact)
if fuse_gemm_and_bias_fc1_wgrad:
fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs
fc1_wgrad, fc1_bias_grad = fc1_wgrad_outputs
else:
fc1_wgrad, *_ = fc1_wgrad_outputs
fc1_wgrad, _ = fc1_wgrad_outputs
# Deallocate tensors if permitted
clear_tensor_data(dact)
if not ctx.return_layernorm_output_gathered:
clear_tensor_data(ln_out_total)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
if ub_obj_fc1_wgrad.is_fp8_ubuf():
fc1_dgrad = fc1_dgrad_rs_out
fc1_dgrad = reduce_scatter_out
else:
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True)
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True).clone()
# --------------------------------------------------
# Finished FC1 WGRAD...
# --------------------------------------------------
# Make sure all tensor-parallel communication is finished
if ln_out_total_work is not None:
@@ -1748,7 +1863,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
) = [None] * 12
if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = False # temporary
fc1_input_quantizer.internal = True
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
@@ -1756,6 +1871,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
)
fc1_input_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output:
@@ -1764,11 +1880,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
]
if torch.is_grad_enabled():
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
]
fc2_grad_output_quantizer.internal = True
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
]
fc1_grad_output_quantizer.internal = True
@@ -1853,25 +1969,25 @@ class LayerNormMLP(TransformerEngineBaseModule):
else:
# fc2_grad_output_quantizer: set amax-epsilon and power-2-scale configs
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# fc1_grad_output_quantizer: apply the same numerical configs
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
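# Power-2 scales round the FP8 scale factor down to a power of two so
# that scaling and unscaling are exact in binary arithmetic. A minimal
# standalone sketch (448.0 is the E4M3 representable maximum; the helper
# name is illustrative):
#
#   import math
#   def pow2_scale(amax, fp8_max=448.0):
#       scale = fp8_max / amax
#       return 2.0 ** math.floor(math.log2(scale))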
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: enable amax reduction over the TP group (row-parallel + sequence-parallel case)
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def backward_dw(self):
@@ -6,24 +6,26 @@
from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import warnings
import functools
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
get_workspace,
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_workspace,
TransformerEngineBaseModule,
get_dummy_wgrad,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ._common import noop_cat, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
@@ -32,7 +34,6 @@ from ..utils import (
init_method_constant,
requires_grad,
needs_quantized_gemm,
non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec,
nvtx_range_pop,
nvtx_range_push,
@@ -57,6 +58,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
@@ -125,88 +127,100 @@ class _Linear(torch.autograd.Function):
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
assert inp.shape[-1] == in_features, "GEMM not possible"
# Configure tensor-parallel communication
tp_world_size = get_distributed_world_size(tp_group)
backward_needs_input = is_grad_enabled and weight.requires_grad
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp.view(-1, in_features)
inputmat_total = None
with_input_all_gather_nccl = (
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
)
own_quantized_input = False
# TODO(kwyss): Support FP8 allgather for FP8 block quantization.
# Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision.
)
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
# ------------------------------------------------------
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp # Input tensor to save for backward (maybe sharded)
inputmat_total = None # Input tensor to pass to GEMM (gathered)
own_quantized_input = False
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
):
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if with_input_all_gather_nccl:
if force_hp_input_gather:
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat, tp_group, quantizer=input_quantizer
)
else:
if not isinstance(inputmat, QuantizedTensor):
columnwise_usage = backward_needs_input and isinstance(
input_quantizer, MXFP8Quantizer
)
# force_hp_input_gather should enforce this
assert not isinstance(input_quantizer, Float8BlockQuantizer)
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat,
tp_group,
quantizer=input_quantizer,
)
if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor
# Cast local input tensor if needed
if fp8 or debug:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
# All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(columnwise=False)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
else:
if (
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad
):
# reduce duplicated transpose in `_fix_gathered_fp8_transpose`
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
# Initialize gathered input tensor
quantizer = None
if fp8 or debug:
quantizer = input_quantizer
quantizer.set_usage(rowwise=True, columnwise=False)
if with_input_all_gather_nccl: # Perform NCCL all-gather
inputmat_total, _ = gather_along_first_dim(
inputmat,
tp_group,
quantizer=quantizer,
)
elif ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inputmat,
quantizer,
tp_group,
)
else: # Do not all-gather input tensor
if fp8 or debug:
if isinstance(inputmat, QuantizedTensorBase):
inputmat.update_usage(rowwise_usage=True)
else:
input_quantizer.set_usage(
rowwise=True,
columnwise=backward_needs_input,
)
if not isinstance(inputmat, QuantizedTensor):
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
elif backward_needs_input:
inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
inputmat_total = inputmat
else:
inputmat = cast_if_needed(inp, activation_dtype)
if with_input_all_gather_nccl:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else:
inputmat_total = inputmat
inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
inputmat_total = inputmat
nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# ------------------------------------------------------
# Input tensor is ready for GEMM...
# ------------------------------------------------------
# Cast weight to expected dtype
# ------------------------------------------------------
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight
if fp8 or debug:
# Configure quantizer
if weight_quantizer is not None:
@@ -217,7 +231,8 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase()
)
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# FP8 cast to workspace buffer
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace(
tensor=weight,
@@ -228,19 +243,21 @@ class _Linear(torch.autograd.Function):
fsdp_group=fsdp_group,
workspace_dtype=activation_dtype,
)
weightmat.update_usage(rowwise_usage=True)
else:
weightmat = cast_if_needed(weightmat, activation_dtype)
weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP
# ------------------------------------------------------
# Weight tensor is ready for GEMM...
# ------------------------------------------------------
# Cast bias to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Calibrate quantizers if needed
if not fp8 and fp8_calibration:
if input_quantizer is not None:
@@ -248,44 +265,74 @@ class _Linear(torch.autograd.Function):
if weight_quantizer is not None:
weight_quantizer.calibrate(weight)
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Output buffer for Userbuffers reduce-scatter
reduce_scatter_out = None
if ub_overlap_rs_fprop:
out_shape = list(inp.shape)
out_shape[0] //= tp_world_size
out_shape[-1] = out_features
reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
# ------------------------------------------------------
# Forward GEMM
# Note: y = x * w^T
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm")
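# Note: general_gemm returns a tuple whose first element is the GEMM
# output and whose last element is the extra communication output
# (the reduce-scatter result when ub_type is RS); the middle entries
# are unused here.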
gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat,
inputmat_total,
get_workspace(),
quantization_params=output_quantizer,
out_dtype=activation_dtype,
bias=bias,
use_split_accumulator=use_split_accumulator,
ub=ub_obj,
ub_type=ub_type,
extra_output=reduce_scatter_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")
# ------------------------------------------------------
# Finished forward GEMM...
# ------------------------------------------------------
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
# ------------------------------------------------------
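# Note: With sequence parallelism the partial outputs are
# reduce-scattered along the leading (sequence) dim; otherwise they
# are summed with an all-reduce, optionally using the symmetric-memory
# kernel selected by symmetric_ar_type.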
out = None
if ub_overlap_rs_fprop:
out = reduce_scatter_out
elif parallel_mode == "row" and tp_size > 1:
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
out = gemm_out
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
else:
out = gemm_out
# ------------------------------------------------------
# Output tensor is ready to return...
# ------------------------------------------------------
# ------------------------------------------------------
# Cache state for backward pass
# ------------------------------------------------------
if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer
@@ -296,19 +343,19 @@ class _Linear(torch.autograd.Function):
)
if backward_needs_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is
# needed to gather the input. For MXFP8, only columnwise
# data can be all-gathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensorBase)
saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensorBase):
weightmat.update_usage(columnwise_usage=True)
if cpu_offloading and saved_inputmat is not None:
@@ -321,7 +368,7 @@ class _Linear(torch.autograd.Function):
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
saved_inputmat,
weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
@@ -364,7 +411,7 @@ class _Linear(torch.autograd.Function):
ctx.use_bias = bias is not None
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
ctx.ub_overlap_ag = ub_overlap_ag_dgrad
@@ -376,6 +423,7 @@ class _Linear(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
ctx.owns_input = saved_inputmat is not inp
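# Note: owns_input records whether the saved input is a copy created
# by this module (e.g. a quantized cast) rather than the caller's
# tensor, so the backward pass knows it may safely deallocate it.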
if ctx.fp8 and requires_grad(inp, weight, bias):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
@@ -384,21 +432,10 @@ class _Linear(torch.autograd.Function):
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store
# ------------------------------------------------------
# Cached state for backward pass is ready...
# ------------------------------------------------------
out = out.view(-1, *inp_shape[1:-1], out_features)
return out
@staticmethod
@@ -411,28 +448,11 @@ class _Linear(torch.autograd.Function):
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_Linear_backward"):
if (
ctx.fp8
and any(
[
ctx.ub_overlap_ag,
ctx.ub_overlap_rs_dgrad,
ctx.ub_bulk_dgrad,
ctx.ub_bulk_wgrad,
]
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
)
# Delete the references to tensor objects once they've been consumed
# by `restore_from_saved` to reconstruct the actual tensors.
ctx.tensor_objects = None
@@ -462,69 +482,55 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None
ub_obj_dgrad = None
ub_obj_wgrad = None
ub_type_dgrad = None
ub_type_wgrad = None
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS
else:
if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG
if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_type_wgrad = tex.CommOverlapType.RS
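# Note: The ub_overlap_* modes pipeline the GEMM's own communication
# with that same GEMM, while the ub_bulk_* modes overlap an unrelated
# transfer (input all-gather, dgrad reduce-scatter) with the GEMM.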
# --------------------------------------------------
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# --------------------------------------------------
# Unmodified grad output tensor
grad_output_arg = grad_output
# Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.grad_output_quantizer is not None:
quantizer = ctx.grad_output_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag:
# Userbuffers only supports communication for one
# tensor usage at a time. Configure quantizer with
# usage for only dgrad GEMM.
quantizer.set_usage(columnwise=False)
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
(
grad_output,
@@ -537,12 +543,21 @@ class _Linear(torch.autograd.Function):
)
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# --------------------------------------------------
# Grad output tensor is ready for computing grad input...
# --------------------------------------------------
# --------------------------------------------------
# Prepare input tensor
# Note: Input tensor is needed for wgrad GEMM.
# Tensor-parallel communication is overlapped with dgrad
# GEMM.
# --------------------------------------------------
inputmat_total = None
inputmat_total_work = None
if ctx.backward_input_needs_gather:
quantizer = None
if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather:
quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
@@ -550,72 +565,92 @@ class _Linear(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
gather_quantizer = None if ctx.force_hp_input_gather else quantizer
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
if ctx.ub_bulk_dgrad:
inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_dgrad,
inputmat,
quantizer,
ctx.tp_group,
)
else:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
inputmat_total = inputmat
# --------------------------------------------------
# Input tensor is ready for computing grad weight...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad input tensor
# --------------------------------------------------
dgrad = None
dgrad_work = None
if ctx.requires_dgrad:
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
weight_fp8.update_usage(columnwise_usage=True)
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8:
recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"):
use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
# Update grad input quantizer
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
)
elif ctx.ub_bulk_wgrad:
gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
# dgrad GEMM
# Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8,
grad_output,
get_workspace(),
layout="NN",
grad=True,
quantization_params=ctx.grad_input_quantizer,
out=gemm_out,
out_dtype=ctx.activation_dtype,
use_split_accumulator=use_split_accumulator,
ub=ub_obj_dgrad,
ub_type=ub_type_dgrad,
extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
# Prepare grad input tensor
# Note: Perform tensor-parallel communication
if ctx.ub_overlap_rs_dgrad:
dgrad = reduce_scatter_out
elif ctx.ub_bulk_wgrad:
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
dgrad = gemm_out
if ctx.sequence_parallel:
dgrad, dgrad_work = reduce_scatter_along_first_dim(
dgrad,
@@ -625,41 +660,55 @@ class _Linear(torch.autograd.Function):
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
else:
dgrad = gemm_out
# --------------------------------------------------
# Grad input tensor has been computed...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad weight
# --------------------------------------------------
wgrad = None
if ctx.requires_wgrad:
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if inputmat_total_work is not None:
inputmat_total_work.wait()
inputmat_total_work = None
if ctx.fp8 or ctx.debug:
if isinstance(inputmat_total, QuantizedTensorBase):
inputmat_total.update_usage(columnwise_usage=True)
else:
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
inputmat_total = ctx.input_quantizer(inputmat_total)
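# Note: In this path the input was gathered in high precision, so it
# is quantized here with column-wise usage only, since the wgrad GEMM
# is its sole consumer.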
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around this with a blocking
# all-gather for column-scaled MXFP8 data.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_output_arg,
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.grad_output_quantizer(grad_output)
# Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD
@@ -668,54 +717,95 @@ class _Linear(torch.autograd.Function):
if hasattr(recipe, "fp8_gemm_wgrad"):
use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator
# Figure out whether to output wgrad GEMM directly into main grad
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Output buffer for overlapping FP8 grad input
# reduce-scatter with wgrad GEMM
reduce_scatter_out = None
if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
)
# Arguments to include in wgrad GEMM closure
wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
),
"quantization_params": ctx.grad_weight_quantizer,
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": (bias if (grad_bias is None and not ctx.fp8) else None),
"use_split_accumulator": use_split_accumulator,
"grad": True,
"ub": ub_obj_wgrad,
"ub_type": ub_type_wgrad,
"extra_output": reduce_scatter_out,
"bulk_overlap": ctx.ub_bulk_wgrad,
}
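# Packing the arguments into a kwargs dict lets the same GEMM call be
# issued immediately or deferred through wgrad_store, e.g. (sketch,
# names illustrative):
# dw, db = wgrad_gemm(x_total, dy) # immediate
# wgrad_store.put([x_total, dy], wgrad_gemm) # deferred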
def wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform wgrad GEMM: dw = dy^T * x
May be fused with bgrad computation.
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
return dw, db
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
if (
wgrad_gemm_kwargs["ub"] is not None
or wgrad_gemm_kwargs["ub_type"] is not None
or wgrad_gemm_kwargs["extra_output"] is not None
or wgrad_gemm_kwargs["bulk_overlap"]
):
raise NotImplementedError(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
else:
# Call wgrad GEMM now
wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)
# Update grad bias if needed
if grad_bias is None:
grad_bias = grad_bias_
del grad_bias_
# Deallocate input tensor if permitted
if ctx.owns_input:
clear_tensor_data(inputmat_total)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
dgrad = reduce_scatter_out
else:
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()
# --------------------------------------------------
# Grad weight has been computed...
# --------------------------------------------------
# Don't return grad bias if not needed
if not ctx.use_bias:
@@ -753,13 +843,14 @@ class _Linear(torch.autograd.Function):
else:
wgrad = None
# Update FP8 scaling factors if needed
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
_fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)
return (
wgrad,
@@ -1207,7 +1298,12 @@ class Linear(TransformerEngineBaseModule):
"Splitting QuantizedTensor into multiple params is not supported"
)
else:
warnings.warn(
"You are using quantized weights without quantized compute. "
"Please make sure this is intentional."
)
unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights)
if self.use_bias:
bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
@@ -1302,7 +1398,7 @@ class Linear(TransformerEngineBaseModule):
grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
if fp8_output:
@@ -98,6 +98,9 @@ class RMSNorm(_RMSNormOp):
)
kwargs["dtype"] = params_dtype
# Flag for sequence parallelism (custom Megatron-LM integration)
self.sequence_parallel: Optional[bool] = sequence_parallel
# Initialize RMSNorm operation
super().__init__(
normalized_shape,
@@ -106,11 +109,6 @@ class RMSNorm(_RMSNormOp):
**kwargs,
)
if sequence_parallel is not None:
self.weight.sequence_parallel = sequence_parallel
def reset_rms_norm_parameters(self) -> None:
"""Deprecated"""
warnings.warn(
@@ -139,7 +137,7 @@ class RMSNorm(_RMSNormOp):
super().reset_parameters()
# Flag for sequence parallelism (custom Megatron-LM integration)
if self.sequence_parallel is not None:
self.weight.sequence_parallel = self.sequence_parallel
@property
@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensor):
if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
return y, x_local, w
@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
# Check datatype
if dtype is None:
if weight is not None:
dtype = weight.dtype
else:
dtype = grad_output.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
x_async = None
dy_async = None
# Check grad weight tensor
dw = grad_weight
dw_dtype = dtype
if dw is None:
@@ -4,30 +4,27 @@
"""Linear layer backward with Userbuffers communication."""
# pylint: skip-file ### TODO Debug Userbuffers support
from __future__ import annotations
from typing import Optional
import warnings
import torch
from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...distributed import gather_along_first_dim, get_distributed_world_size
from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_workspace,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import FusedOperation, FusibleOperation, OperationContext
class UserbuffersBackwardLinear(FusedOperation):
@@ -47,9 +44,6 @@ class UserbuffersBackwardLinear(FusedOperation):
reduce_scatter: Optional[ReduceScatter],
) -> None:
# Basic operations that comprise this fused operation
op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
ops = []
@@ -89,9 +83,8 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_output: torch.Tensor,
input: Optional[torch.Tensor], # pylint: disable=redefined-builtin
weight: Optional[torch.Tensor],
*,
input_requires_grad: bool = True,
weight_requires_grad: bool = True,
bias_requires_grad: bool = False,
device: Optional[torch.device] = None,
@@ -102,11 +95,11 @@ class UserbuffersBackwardLinear(FusedOperation):
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None,
sequence_parallel: bool = False,
with_quantized_compute: bool = False,
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
grad_output_quantizer: Optional[Quantizer] = None,
grad_input_quantizer: Optional[Quantizer] = None,
ub_comm_name: str,
) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]:
"""Functional API for backward pass
@@ -121,10 +114,6 @@ class UserbuffersBackwardLinear(FusedOperation):
weight: torch.Tensor, optional
Weight tensor. Required to compute loss gradient w.r.t.
input.
weight_requires_grad: bool
Whether to compute loss gradient w.r.t. weight tensor
bias_requires_grad: bool
@@ -146,21 +135,18 @@ class UserbuffersBackwardLinear(FusedOperation):
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_quantized_compute: bool, default = `False`
Whether to perform compute with quantized data.
input_quantizer: Quantizer, optional
Builder class for quantized input tensor.
weight_quantizer: Quantizer, optional
Builder class for quantized weight tensor.
grad_output_quantizer: Quantizer, optional
Builder class for quantized loss gradient w.r.t. output
tensor.
grad_input_quantizer: Quantizer, optional
Builder class for quantized loss gradient w.r.t. input
tensor.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
@@ -183,37 +169,24 @@ class UserbuffersBackwardLinear(FusedOperation):
# Check device
if device is None:
if weight is not None:
device = weight.device
else:
device = grad_output.device
device = canonicalize_device(device)
if device.type != "cuda":
raise ValueError(f"Only CUDA devices are supported (got {device})")
# Check datatype
if dtype is None:
if weight is not None:
dtype = weight.dtype
else:
dtype = grad_output.dtype
dtype = canonicalize_dtype(dtype)
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Check tensor parallel group
if tensor_parallel_size is None:
tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation):
if not sequence_parallel:
raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
# dgrad GEMM is required
if not input_requires_grad:
warnings.warn(
"Linear input doesn't require gradient, "
"but Userbuffers implementation requires dgrad GEMM."
)
input_requires_grad = True
# Check quantizers
if with_quantized_compute:
if weight_requires_grad and input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if input_requires_grad and weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if grad_output_quantizer is None:
raise ValueError("Missing quantizer for grad output tensor")
if grad_input_quantizer is not None:
raise ValueError("Quantized grad input is not supported")
else:
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
# Get Userbuffers communicators
# Note: Communication patterns are (1) overlap dy all-gather
# with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
# and dx reduce-scatter with wgrad GEMM, (3) overlap dx
# reduce-scatter with dgrad GEMM
ub_comm_dgrad = None
ub_comm_wgrad = None
ub_type_dgrad = None
ub_type_wgrad = None
with_bulk_overlap = False
with_dgrad_all_gather_dy = False
with_dgrad_reduce_scatter_dx = False
with_dgrad_all_gather_x = False
with_wgrad_reduce_scatter_dx = False
if tensor_parallel_mode == "row":
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_dy = True
elif tensor_parallel_mode == "column":
if input_requires_grad and weight_requires_grad:
with_bulk_overlap = True
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_type_dgrad = CommOverlapType.AG
with_dgrad_all_gather_x = True
ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
ub_type_wgrad = CommOverlapType.RS
with_wgrad_reduce_scatter_dx = True
if ub_comm_wgrad.is_fp8_ubuf():
raise RuntimeError(
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
else:
ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
ub_type_dgrad = CommOverlapType.RS
with_dgrad_reduce_scatter_dx = True
if ub_comm_dgrad.is_fp8_ubuf():
raise RuntimeError(
"Userbuffers reduce-scatter is not supported with FP8 buffers"
)
# Compute grad bias if needed
db = None
db_async = None
if bias_requires_grad:
db = grad_output.sum(tuple(range(grad_output.dim() - 1)))
if tensor_parallel_mode == "row":
db_async = torch.distributed.all_reduce(
db,
group=tensor_parallel_group,
async_op=True,
)
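# Note: The grad bias is the sum of grad output over all but the last
# dim; with row tensor-parallelism each rank holds only a sequence
# shard, so the partial sums are all-reduced asynchronously to hide
# the latency behind the GEMMs below.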
# Cast grad output tensor dtype if needed
dy_local = grad_output
if with_quantized_compute:
if not isinstance(dy_local, QuantizedTensorBase):
with_columnwise = weight_requires_grad
if (
with_columnwise
and with_dgrad_all_gather_dy
and not isinstance(grad_output_quantizer, MXFP8Quantizer)
):
with_columnwise = False
grad_output_quantizer.set_usage(
rowwise=True,
columnwise=with_columnwise,
)
dy_local = grad_output_quantizer(dy_local)
else:
if isinstance(dy_local, QuantizedTensorBase):
dy_local = dy_local.dequantize(dtype=dtype)
elif dy_local.dtype != dtype:
dy_local = dy_local.to(dtype=dtype)
# Cast weight tensor dtype if needed
if weight is None:
raise ValueError("Weight tensor is required to compute input grad")
w = weight
if with_quantized_compute:
if not isinstance(w, QuantizedTensorBase):
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
if isinstance(w, QuantizedTensorBase):
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)
# Cast input tensor dtype if needed
x_local = None
if weight_requires_grad:
if input is None:
raise ValueError("Input tensor is required to compute weight grad")
x_local = input
if with_quantized_compute:
if not isinstance(x_local, QuantizedTensorBase):
input_quantizer.set_usage(columnwise=True)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
x_local = x_local.dequantize(dtype=dtype)
elif x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
# dgrad GEMM
dx_local = None
dx = None
dy = None
x = None
if input_requires_grad:
# Initialize grad output
if with_dgrad_all_gather_dy:
if grad_output_quantizer is not None:
grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
dy, _ = fill_userbuffers_buffer_for_all_gather(
ub_comm_dgrad,
dy_local,
grad_output_quantizer,
tensor_parallel_group,
)
else:
dy = dy_local
# Construct grad input tensor if needed
if with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx:
dx_size = list(dy.size())
dx_size[-1] = w.size(-1)
dx_local_size = list(dx_size)
dx_local_size[0] //= tensor_parallel_size
if with_dgrad_reduce_scatter_dx:
dx_local = torch.empty(
dx_local_size,
dtype=dtype,
device=device,
)
elif with_wgrad_reduce_scatter_dx:
dx_local = ub_comm_wgrad.get_buffer(
local_chunk=True,
shape=dx_local_size,
)
dx = ub_comm_wgrad.get_buffer(
local_chunk=False,
shape=dx_size,
)
# Initialize input tensor if needed
if with_dgrad_all_gather_x:
if input_quantizer is not None:
input_quantizer.set_usage(rowwise=False, columnwise=True)
x, _ = fill_userbuffers_buffer_for_all_gather(
ub_comm_dgrad,
x_local,
input_quantizer,
tensor_parallel_group,
)
# Perform dgrad GEMM
dx, *_ = general_gemm(
w,
dy,
get_workspace(),
out_dtype=dtype,
quantization_params=grad_input_quantizer,
layout="NN",
out=dx,
use_split_accumulator=_2X_ACC_DGRAD,
grad=True,
ub=ub_comm_dgrad,
ub_type=ub_type_dgrad,
extra_output=dx_local if with_dgrad_reduce_scatter_dx else None,
bulk_overlap=with_bulk_overlap,
)
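# Note: When the dgrad reduce-scatter is overlapped here, the
# scattered grad input lands in extra_output (dx_local) while the
# unreduced GEMM result stays in the Userbuffers workspace.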
if not (with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx):
dx_local = dx
# wgrad GEMM
dw = None
if weight_requires_grad:
# Initialize grad output
if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, MXFP8 does not
# allow reusing the grad output that was gathered for
# the dgrad GEMM. We work around this with a blocking
# all-gather for column-scaled MXFP8 data.
grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
dy, _ = gather_along_first_dim(
grad_output,
tensor_parallel_group,
quantizer=grad_output_quantizer,
)
else:
if tensor_parallel_mode == "column":
dy = dy_local
if dy is None:
raise RuntimeError(
"wgrad GEMM requires grad output tensor, which has not been initialized"
)
if isinstance(dy, QuantizedTensorBase):
dy.update_usage(rowwise_usage=False, columnwise_usage=True)
# Initialize input tensor
if tensor_parallel_mode == "row":
x = x_local
if x is None:
raise RuntimeError(
"wgrad GEMM requires input tensor, which has not been initialized"
)
if isinstance(x, QuantizedTensorBase):
x.update_usage(rowwise_usage=False, columnwise_usage=True)
# Check grad weight tensor
dw = grad_weight
dw_dtype = dtype
if dw is None:
if accumulate_into_grad_weight:
raise ValueError(
"Attempted to accumulate into grad weight tensor "
"without providing grad weight tensor"
)
else:
dw_dtype = dw.dtype
# Perform wgrad GEMM
dw, *_ = general_gemm(
x,
dy,
get_workspace(),
out_dtype=dw_dtype,
accumulate=accumulate_into_grad_weight,
layout="NT",
out=dw,
use_split_accumulator=_2X_ACC_WGRAD,
grad=True,
ub=ub_comm_wgrad,
ub_type=ub_type_wgrad,
bulk_overlap=with_bulk_overlap,
)
# Bulk overlap reduce-scatter with non-FP8 buffer is
# in-place. Need to copy grad input tensor to avoid data
# corruption in Userbuffers buffer.
if with_wgrad_reduce_scatter_dx:
dx_local = dx_local.clone()
# Compute grad bias if needed
if db_async is not None:
db_async.wait()
if bias_requires_grad:
if db is None:
db = dy.sum(dim=0)
extra_outputs["grad_bias"] = db
return dx_local, dw, extra_outputs
def fuser_backward(
self,
@@ -633,40 +516,24 @@ class UserbuffersBackwardLinear(FusedOperation):
else:
accumulate_into_main_grad = False
# Linear backward pass
retval = UserbuffersBackwardLinear._functional_backward(
grad_output=grad_output,
input=x_local,
weight=linear_op.weight,
weight_requires_grad=linear_op_ctx.weight_requires_grad,
bias_requires_grad=(bias_op is not None),
device=linear_op.device,
dtype=linear_op_ctx.dtype,
grad_weight=grad_weight,
accumulate_into_grad_weight=accumulate_into_main_grad,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
sequence_parallel=self.sequence_parallel,
with_quantized_compute=linear_op_ctx.with_quantized_compute,
input_quantizer=linear_op_ctx.input_quantizer,
weight_quantizer=linear_op_ctx.weight_quantizer,
grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
grad_input_quantizer=None, # Not supported
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
grad_input, grad_weight, extra_outputs = retval
@@ -707,8 +574,6 @@ def fuse_userbuffers_backward_linear(
"""
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
@@ -4,20 +4,25 @@
"""Linear layer forward with Userbuffers communication."""
# pylint: skip-file ### TODO Debug Userbuffers support
from __future__ import annotations
from typing import Any, Optional
import torch
from transformer_engine_torch import CommOverlapType
from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager
from ...module.base import (
fill_userbuffers_buffer_for_all_gather,
get_ub,
get_workspace,
_2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
@@ -26,12 +31,6 @@ from ..op import (
FusibleOperation,
OperationContext,
)
class UserbuffersForwardLinear(FusedOperation):
@@ -51,9 +50,6 @@ class UserbuffersForwardLinear(FusedOperation):
reduce_scatter: Optional[ReduceScatter],
) -> None:
# Basic operations that comprise this fused operation
op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None}
ops = [linear]
@@ -98,10 +94,10 @@ class UserbuffersForwardLinear(FusedOperation):
tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
tensor_parallel_size: Optional[int] = None,
sequence_parallel: bool = False,
with_quantized_compute: bool = False,
input_quantizer: Optional[Quantizer] = None,
weight_quantizer: Optional[Quantizer] = None,
output_quantizer: Optional[Quantizer] = None,
ub_comm_name: str,
) -> tuple[torch.Tensor, dict]:
"""Functional API for forward pass
@@ -127,16 +123,14 @@ class UserbuffersForwardLinear(FusedOperation):
parallelism, i.e. distributing input or output tensors
along outer dimension (sequence or batch dim) when not
distributing along inner dimension (embedding dim)
with_quantized_compute: bool, default = `False`
Whether to perform compute with quantized data.
input_quantizer: Quantizer, optional
Builder class for quantized input tensor.
weight_quantizer: Quantizer, optional
Builder class for quantized weight tensor.
output_quantizer: Quantizer, optional
Builder class for quantized output tensor.
ub_comm_name: str
Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
used to access the corresponding Userbuffers communicators
@@ -166,23 +160,6 @@ class UserbuffersForwardLinear(FusedOperation):
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
# Check tensor parallel group
if tensor_parallel_size is None:
tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
@@ -196,235 +173,106 @@ class UserbuffersForwardLinear(FusedOperation):
if not sequence_parallel:
raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
# Check quantizers
if with_quantized_compute:
if input_quantizer is None:
raise ValueError("Missing quantizer for input tensor")
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
if output_quantizer is not None:
raise ValueError("FP8 output is not supported")
else:
input_quantizer = None
weight_quantizer = None
output_quantizer = None
# Get Userbuffers communicator
ub_comm = get_ub(ub_comm_name + "_fprop")
ub_local_buffer = ub_comm.get_ubuf_output(0)
ub_global_buffer = ub_comm.get_ubuf_output(1)
with_ub_all_gather = tensor_parallel_mode == "column"
with_ub_reduce_scatter = tensor_parallel_mode == "row"
ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS
# Choose Userbuffers communication algorithm
ub_algo = None
# Initialize input tensor
x_local = input
x = None
if with_ub_all_gather:
if with_fp8_compute and ub_comm.is_atomic_gemm():
ub_algo = CommOverlapAlgo.ATOMIC_GEMM_AG_P2P
else:
ub_algo = CommOverlapAlgo.SPLIT_PIPELINED_AG_P2P
elif with_ub_reduce_scatter:
is_atomic_gemm = with_fp8_compute and ub_comm.is_atomic_gemm()
ub_algo = {
(True, True): CommOverlapAlgo.ATOMIC_GEMM_RS_P2P,
(True, False): CommOverlapAlgo.SPLIT_PIPELINED_RS_P2P,
(False, True): CommOverlapAlgo.ATOMIC_GEMM_RS,
(False, False): CommOverlapAlgo.SPLIT_PIPELINED_RS,
}[(ub_comm.is_p2p_overlap(), is_atomic_gemm)]
else:
raise RuntimeError("Could not choose Userbuffers communication algorithm")
# Cast input tensor to correct dtype
x_local = reshape(
input,
(-1, input_dims[-1]),
device=device,
dtype=dtype,
)
if with_fp8_compute and not is_float8_tensor(x_local):
    fp8_dtype = get_fp8_te_dtype(
        input_fp8_meta["recipe"],
        fprop_tensor=True,
    )
    with_transpose_cache = weight.requires_grad
    if tensor_parallel_mode == "column" and sequence_parallel:
        with_transpose_cache = False
    x_local = Float8Tensor.to_float8(
        x_local,
        fp8_meta=input_fp8_meta,
        fp8_meta_forward=True,
        fp8_meta_index=0,
        fp8_dtype=fp8_dtype,
        data=(ub_local_buffer if with_ub_all_gather else None),
        with_transpose_cache=with_transpose_cache,
    )
if input_quantizer is not None:
    if not isinstance(x_local, QuantizedTensorBase):
        input_quantizer.set_usage(rowwise=True, columnwise=True)
        if isinstance(input_quantizer, Float8Quantizer):
            input_quantizer.set_usage(columnwise=False)
        x_local = input_quantizer(x_local)
    input_quantizer.set_usage(rowwise=True, columnwise=False)
x, x_local = fill_userbuffers_buffer_for_all_gather(
    ub_comm,
    x_local,
    input_quantizer,
    tensor_parallel_group,
)
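Reviewer note: the rowwise/columnwise usage flags above control which layouts the quantized tensor materializes (row-wise data for the forward GEMM, column-wise/transposed data for the backward pass). A toy sketch of that toggle pattern (hypothetical ToyQuantizer stand-in; the real API is TransformerEngine's Quantizer.set_usage):

```python
# Usage-flag pattern for a quantizer (standalone sketch, hypothetical class).
import torch

class ToyQuantizer:
    """Tracks which layouts (row-wise data, column-wise transpose) to produce."""
    def __init__(self):
        self.rowwise, self.columnwise = True, False
    def set_usage(self, rowwise=None, columnwise=None):
        if rowwise is not None:
            self.rowwise = rowwise
        if columnwise is not None:
            self.columnwise = columnwise
    def __call__(self, t: torch.Tensor) -> dict:
        out = {"data": t.to(torch.float16)} if self.rowwise else {}
        if self.columnwise:
            out["data_t"] = t.t().contiguous().to(torch.float16)
        return out

q = ToyQuantizer()
q.set_usage(rowwise=True, columnwise=False)  # GEMM input: row-wise only
qt = q(torch.randn(4, 8))
assert "data" in qt and "data_t" not in qt
```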
elif not with_fp8_compute and is_float8_tensor(x_local):
    if with_ub_all_gather:
        x_local = ub_local_buffer.copy_(x_local)
    else:
        x_local = x_local.dequantize()
# Initialize buffers for UB all-gather if needed
x = x_local
if with_ub_all_gather:
    if with_fp8_compute:
        x = Float8Tensor.make_like(x_local, data=ub_global_buffer)
        if x_local._data.data_ptr() != ub_local_buffer.data_ptr():
            ub_local_buffer.copy_(x_local._data)
        else:
            x_local._data = torch.empty_like(x_local._data)
    else:
        x = ub_global_buffer
        if x_local.data_ptr() != ub_local_buffer.data_ptr():
            ub_local_buffer.copy_(x_local)
        else:
            x_local = torch.empty_like(x_local)
# Check weight tensor
w = convert_tensor(
    weight,
    device=device,
    dtype=dtype,
    memory_format=torch.contiguous_format,
)
if with_fp8_compute and not is_float8_tensor(w):
    fp8_dtype = get_fp8_te_dtype(
        weight_fp8_meta["recipe"],
        fprop_tensor=True,
    )
    w = Float8Tensor.to_float8(
        w,
        fp8_meta=weight_fp8_meta,
        fp8_meta_index=0,
        fp8_dtype=fp8_dtype,
    )
elif not with_fp8_compute and is_float8_tensor(w):
    w = w.dequantize()
else:
    if with_quantized_compute:
        if not isinstance(x_local, QuantizedTensorBase):
            input_quantizer.set_usage(rowwise=True, columnwise=True)
            x_local = input_quantizer(x_local)
    else:
        if isinstance(x_local, QuantizedTensorBase):
            x_local = x_local.dequantize(dtype=dtype)
        if x_local.dtype != dtype:
            x_local = x_local.to(dtype=dtype)
    x = x_local
# Initialize weight tensor
w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized:
    weight_quantizer.set_usage(rowwise=True)
    w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
    w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
    w = w.to(dtype=dtype)
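Reviewer note: in the column-parallel path above, the local input shard is staged into a shared Userbuffers buffer so the all-gather can overlap with the GEMM. The same communication pattern in plain torch.distributed, without the shared buffer or overlap (sketch only; assumes dist.init_process_group has been called):

```python
# Plain all-gather of a sequence-sharded input (standalone sketch).
import torch
import torch.distributed as dist

def all_gather_input(x_local: torch.Tensor, group=None) -> torch.Tensor:
    world = dist.get_world_size(group)
    # Gathered tensor concatenates the shards along the first dim.
    x = torch.empty(
        (x_local.size(0) * world, *x_local.shape[1:]),
        dtype=x_local.dtype, device=x_local.device,
    )
    dist.all_gather_into_tensor(x, x_local.contiguous(), group=group)
    return x
```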
# Check bias tensor
b = None
if bias is not None:
b = convert_tensor(
bias,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
# Construct output tensor
y = None
y_local = None
if with_ub_reduce_scatter:
    # Initialize buffers for UB reduce-scatter
    if with_fp8_output:
        fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
        fp8_dtype = get_fp8_te_dtype(
            output_fp8_meta["recipe"],
            fprop_tensor=True,
        )
        y = Float8Tensor(
            data=ub_global_buffer,
            fp8_meta=output_fp8_meta,
            fp8_meta_forward=True,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
            fp8_scale_inv=output_fp8_meta[fp8_meta_key].scale_inv[0],
            dtype=dtype,
        )
        ub_comm.set_ubuf_scale_inv(y._scale_inv)
    else:
        y = ub_global_buffer
    y_local = torch.empty(
        (x.size(0) // tensor_parallel_size, weight_dims[0]),
        dtype=dtype,
        device=device,
    )
else:
    # Allocate output tensor
    if with_fp8_output:
        fp8_dtype = get_fp8_te_dtype(
            output_fp8_meta["recipe"],
            fprop_tensor=True,
        )
        data = torch.empty(
            (x.size(0), weight_dims[0]),
            dtype=torch.uint8,
            device=device,
        )
        y = Float8Tensor(
            data=data,
            fp8_meta=output_fp8_meta,
            fp8_meta_forward=True,
            fp8_meta_index=0,
            fp8_dtype=fp8_dtype,
            dtype=dtype,
        )
    else:
        y = torch.empty(
            (x.size(0), weight_dims[0]),
            dtype=dtype,
            device=device,
        )
    y_local = y
# Construct output tensor if needed
reduce_scatter_output = None
if with_ub_reduce_scatter:
    y_local_size = list(x.size())
    y_local_size[0] //= tensor_parallel_size
    y_local_size[-1] = w.size(0)
    reduce_scatter_output = torch.empty(y_local_size, dtype=dtype, device=device)
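Reviewer note: the shape math above reflects the row-parallel epilogue — the full GEMM output is reduce-scattered along the first dim, so each rank keeps first_dim // tensor_parallel_size rows. The unfused equivalent in plain torch.distributed (sketch only; assumes an initialized process group, whereas Userbuffers fuses this collective into the GEMM):

```python
# Reduce-scatter of a row-parallel GEMM output (standalone sketch).
import torch
import torch.distributed as dist

def reduce_scatter_output(y_full: torch.Tensor, group=None) -> torch.Tensor:
    world = dist.get_world_size(group)
    assert y_full.size(0) % world == 0
    # Each rank keeps a sum-reduced slice of the first dim.
    y_local = torch.empty(
        (y_full.size(0) // world, *y_full.shape[1:]),
        dtype=y_full.dtype, device=y_full.device,
    )
    dist.reduce_scatter_tensor(y_local, y_full.contiguous(), group=group)
    return y_local
```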
# Perform GEMM
if with_fp8_compute:
    kwargs = {
        "out": y,
        "bias": b,
        "use_bias": (b is not None),
        "use_split_accumulator": False,
        "ub_algo": ub_algo,
        "ub": ub_comm,
    }
    if with_ub_all_gather:
        kwargs["extra_output_tensor"] = x_local._data
    if with_ub_reduce_scatter:
        kwargs["extra_output_tensor"] = y_local
    if with_fp8_output:
        fp8_meta, fp8_meta_index = get_fp8_meta_from_fp8_tensor(y)
        kwargs.update(
            {
                "out": y._data,
                "out_index": fp8_meta_index,
                "fp8_meta_tensor": fp8_meta,
                "D_dtype": y._fp8_dtype,
            }
        )
    fp8_gemm(
        w._data,
        w._scale_inv,
        0,
        w._fp8_dtype,
        x._data,
        x._scale_inv,
        0,
        x._fp8_dtype,
        y.dtype,
        get_workspace(),
        **kwargs,
    )
else:
    kwargs = {
        "out": y,
        "bias": b,
        "use_bias": (b is not None),
        "ub_algo": ub_algo,
        "ub": ub_comm,
    }
    if with_ub_all_gather:
        kwargs["extra_output_tensor"] = x_local
    if with_ub_reduce_scatter:
        kwargs["extra_output_tensor"] = y_local
    gemm(w, x, y.dtype, get_workspace(), **kwargs)
# Reshape output tensor
out = reshape(y_local, output_dims)
gemm_output, *_, reduce_scatter_output = general_gemm(
    w,
    x,
    get_workspace(),
    out_dtype=dtype,
    quantization_params=output_quantizer,
    bias=bias,
    use_split_accumulator=_2X_ACC_FPROP,
    ub=ub_comm,
    ub_type=ub_type,
    extra_output=reduce_scatter_output,
)
if with_ub_reduce_scatter:
    y_local = reduce_scatter_output
else:
    y_local = gemm_output
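Reviewer note: stripped of quantization and the Userbuffers overlap, the fused GEMM above computes an ordinary linear transform. A reference sketch of the same math in plain PyTorch (illustration only; general_gemm additionally fuses bias, FP8 quantization, and the communication):

```python
# Plain-PyTorch reference for the fused linear GEMM (standalone sketch).
import torch

x = torch.randn(16, 64)      # gathered input (tokens, in_features)
w = torch.randn(128, 64)     # weight (out_features, in_features)
b = torch.randn(128)         # bias

y_ref = torch.nn.functional.linear(x, w, b)   # y = x @ w.T + b
assert y_ref.shape == (16, 128)
```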
# Detach input tensor if needed
# Note: PyTorch autograd produces esoteric errors if we save
# input tensor as context for backward pass.
if x_local is input:
x_local = x_local.detach()
# Configure input tensor for backward pass
if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
# FP8 does not support all-gather of transpose data
x_local.update_usage(rowwise_usage=False, columnwise_usage=True)
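Reviewer note: dropping the row-wise copy above is safe because, after the forward GEMM, the saved input is only consumed by the weight gradient, which uses it through a transposed GEMM (dW = dy^T @ x) — hence only the column-wise layout is needed. A reference math sketch (standalone, not TransformerEngine code):

```python
# Why the backward pass only needs the transposed input (standalone sketch).
import torch

x = torch.randn(16, 64, requires_grad=True)
w = torch.randn(128, 64, requires_grad=True)
y = x @ w.t()
dy = torch.randn_like(y)

dw_ref = dy.t() @ x            # wgrad consumes x via a transposed GEMM
y.backward(dy)
assert torch.allclose(w.grad, dw_ref, atol=1e-4)
```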
# Return cast tensors
extra_outputs = {"input": x_local, "weight": w}
return out, extra_outputs
return y_local, extra_outputs
def fuser_forward(
self,
@@ -450,23 +298,22 @@ class UserbuffersForwardLinear(FusedOperation):
if basic_op_kwargs[idx]:
raise ValueError("Bias operation forward does not expect keyword arguments")
# FP8 metadata
with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
input_fp8_meta = None
weight_fp8_meta = None
output_fp8_meta = None
grad_output_fp8_meta = None
grad_input_fp8_meta = None
if with_fp8_compute:
input_fp8_meta = linear_op.get_fp8_meta("input")
weight_fp8_meta = linear_op.get_fp8_meta("param")
next_op = basic_op_next_ops[-1]
if next_op is not None and next_op.num_fp8_scales("input") > 0:
output_fp8_meta = next_op.get_fp8_meta("input")
grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
# Quantization metadata
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
input_quantizer = None
weight_quantizer = None
grad_output_quantizer = None
grad_input_quantizer = None
if with_quantized_compute:
recipe = FP8GlobalStateManager.get_fp8_recipe()
if not recipe.delayed() and not recipe.mxfp8():
raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
input_quantizer = linear_op.get_quantizer("forward", 0)
weight_quantizer = linear_op.get_quantizer("forward", 1)
grad_output_quantizer = linear_op.get_quantizer("backward", 0)
prev_op = basic_op_prev_ops[0]
if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
grad_input_quantizer = prev_op.get_quantizer("backward", 0)
# Get autocast dtype if needed
dtype = None
@@ -482,26 +329,26 @@ class UserbuffersForwardLinear(FusedOperation):
input=input_,
weight=linear_op.weight,
bias=bias,
device=linear_op.device,
dtype=dtype,
tensor_parallel_mode=self.tensor_parallel_mode,
tensor_parallel_group=self.tensor_parallel_group,
tensor_parallel_size=self.tensor_parallel_size,
sequence_parallel=self.sequence_parallel,
with_fp8_compute=with_fp8_compute,
input_fp8_meta=input_fp8_meta,
weight_fp8_meta=weight_fp8_meta,
output_fp8_meta=output_fp8_meta,
with_quantized_compute=with_quantized_compute,
input_quantizer=input_quantizer,
weight_quantizer=weight_quantizer,
output_quantizer=None, # Not supported
ub_comm_name=linear_op._userbuffers_options["comm_name"],
)
x_local = extra_outputs["input"]
# Save state for backward pass
linear_op_ctx.save_for_backward(x_local)
linear_op_ctx.with_fp8_compute = with_fp8_compute
linear_op_ctx.weight_fp8_meta = weight_fp8_meta
linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
linear_op_ctx.with_quantized_compute = with_quantized_compute
linear_op_ctx.input_quantizer = input_quantizer
linear_op_ctx.weight_quantizer = weight_quantizer
linear_op_ctx.grad_output_quantizer = grad_output_quantizer
linear_op_ctx.grad_input_quantizer = grad_input_quantizer
linear_op_ctx.dtype = dtype
linear_op_ctx.input_dims = input_.size()
linear_op_ctx.input_requires_grad = input_.requires_grad
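Reviewer note: the context wiring above follows the standard save-for-backward pattern — stash the (possibly quantized, detached) input in the forward pass and replay it for dgrad/wgrad. A minimal sketch of the pattern with a plain autograd.Function (illustration only, not the fuser API):

```python
# Save-for-backward pattern for a linear op (standalone sketch).
import torch

class ToyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        # Detach before saving, mirroring the note above about autograd errors.
        ctx.save_for_backward(x.detach(), w)
        return x @ w.t()

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        return dy @ w, dy.t() @ x   # dgrad, wgrad

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(3, 8, requires_grad=True)
ToyLinear.apply(x, w).sum().backward()
```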
@@ -529,8 +376,6 @@ def fuse_userbuffers_forward_linear(
"""
return ops  # TODO: Debug Userbuffers support
# Return immediately if environment is not distributed
if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
return ops
@@ -55,7 +55,17 @@ if __name__ == "__main__":
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["torch"],
setup_requires=[
"torch>=2.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["torch>=2.1"],
tests_require=["numpy", "torchvision"],
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):