Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#pragma once
#include <ATen/ATen.h>
#include "common/utils.cuh"
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Double: { \
using scalar_t_in = double; \
switch (TYPEOUT) { \
case at::ScalarType::Double: { \
using scalar_t_out = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
#ifdef __HIP_PLATFORM_AMD__
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i, THREADS_PER_WARP);
#else
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
}
if (share_result) {
if (tid < lanes) x[tid] = final;  // EpilogueOp
}
// Make sure the smem result is visible to all warps, and avoid a potential
// write-before-read race when reduce_block_into_lanes is called back to back.
__syncthreads();
return final;
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
bool share_result = false) { // lanes is intended to be <= 32.
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
#else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
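
For intuition about the two reduction helpers above, here is a rough PyTorch sketch of the intended semantics of reduce_block_into_lanes (illustrative only; it assumes the block size is a power of two that is a multiple of 64 and that lanes is a power of two no larger than 32, and the function name below is made up for the example):

import torch

def reduce_block_into_lanes_reference(vals: torch.Tensor, lanes: int = 1) -> torch.Tensor:
    # vals holds one value per thread in the block. Lane l (l < lanes) ends up
    # with the sum of vals[l], vals[l + lanes], vals[l + 2 * lanes], ...;
    # with lanes == 1 this is the full block-wide sum.
    return torch.stack([vals[l::lanes].sum() for l in range(lanes)])

# Example: a "block" of 128 threads reduced into 4 lanes.
vals = torch.arange(128, dtype=torch.float32)
partials = reduce_block_into_lanes_reference(vals, lanes=4)
assert torch.isclose(partials.sum(), vals.sum())

The _max_op variant has the same shape but combines values with fmaxf(fabsf(...), fabsf(...)) instead of addition, so each lane ends up holding a maximum of absolute values rather than a partial sum.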
@@ -4,10 +4,10 @@
 * See LICENSE for license information.
 ************************************************************************/
#include "extensions.h"
#include "transformer_engine/transformer_engine.h"
#include "util.h"
#include "common.h"

std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
                                                  bool rowwise) {
using namespace transformer_engine::pytorch;
@@ -45,80 +45,33 @@ std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrap
transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING);
transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
if (rowwise) {
-input_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
-input_cu.set_rowwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
-output_cu.set_rowwise_data(input.dptr(), DType::kFloat8E4M3, input_shape);
-output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
+input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
+input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                               scale_inv_shape);
+output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape);
+output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                scale_inv_shape);
} else {
-input_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
-input_cu.set_columnwise_scale_inv(scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
-output_cu.set_columnwise_data(input.columnwise_dptr(), DType::kFloat8E4M3, input_shape);
-output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0,
-                                   scale_inv_shape);
+input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
+                             input_shape);
+input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                                  scale_inv_shape);
+output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3,
+                              input_shape);
+output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr,
+                                   transformer_engine::DType::kFloat8E8M0, scale_inv_shape);
}
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
if (rowwise) {
-input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
+input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                            scale_inv_shape);
} else {
-input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, DType::kFloat8E8M0, scale_inv_shape);
+input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
+                               scale_inv_shape);
}
return swizzled_scale_inv;
}
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);
void* scale_inv_dptr = getDataPtr(scale_inv, 0);
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), getTensorShape(input),
DType::kFloat8E4M3, nullptr, nullptr, scale_inv_dptr,
getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
auto output_cu = makeTransformerEngineTensor(
input.data_ptr(), getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
swizzled_scale_inv_dptr, getTensorShape(swizzled_scale_inv), NVTE_MXFP8_1D_SCALING);
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return swizzled_scale_inv;
}
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv) {
using namespace transformer_engine::pytorch;
NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors.");
auto options = at::TensorOptions().dtype(scale_inv.dtype()).device(torch::kCUDA);
auto swizzled_scale_inv = at::empty_like(scale_inv, options);
// Return immediately if tensor is empty
if (scale_inv.numel() == 0) {
return swizzled_scale_inv;
}
void* scale_inv_dptr = getDataPtr(scale_inv, 0);
void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
auto input_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, scale_inv_dptr, {1}, getTensorShape(scale_inv), NVTE_MXFP8_1D_SCALING);
auto output_cu = makeTransformerEngineTensor(
nullptr, input.data_ptr(), {1}, getTensorShape(input), DType::kFloat8E4M3, nullptr, nullptr,
nullptr, swizzled_scale_inv_dptr, {1}, getTensorShape(swizzled_scale_inv),
NVTE_MXFP8_1D_SCALING);
// Launch kernel
nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
return swizzled_scale_inv;
}
@@ -13,8 +13,6 @@
#include "transformer_engine/transformer_engine.h"

-bool non_tn_fp8_gemm_supported();

/* Swizzle the scaling factor of the input tensor.
 *
 * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
...
@@ -19,7 +19,12 @@ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._common_utils import _get_module_fsdp_state
from torch.distributed.fsdp._traversal_utils import _get_fsdp_states_with_modules
-from .utils import non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm
+from . import torch_version
+from .utils import (
+    is_non_tn_fp8_gemm_supported,
+    safely_set_viewless_tensor_data,
+    needs_quantized_gemm,
+)
from .constants import dist_group_type
from .fp8 import FP8GlobalStateManager, fp8_autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
@@ -267,17 +272,36 @@ def _get_active_autocast_contexts():
    """
    autocast_cached = torch.is_autocast_cache_enabled()

-    gpu_autocast_enabled = torch.is_autocast_enabled()
-    gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
-    gpu_autocast_ctx = torch.cuda.amp.autocast(
-        gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
-    )
-    cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
-    cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
-    cpu_autocast_ctx = torch.cpu.amp.autocast(
-        cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
-    )
+    if torch_version() >= (2, 4, 0):
+        gpu_autocast_enabled = torch.is_autocast_enabled("cuda")
+        gpu_autocast_dtype = torch.get_autocast_dtype("cuda")
+        gpu_autocast_ctx = torch.amp.autocast(
+            "cuda",
+            enabled=gpu_autocast_enabled,
+            dtype=gpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+        cpu_autocast_enabled = torch.is_autocast_enabled("cpu")
+        cpu_autocast_dtype = torch.get_autocast_dtype("cpu")
+        cpu_autocast_ctx = torch.amp.autocast(
+            "cpu",
+            enabled=cpu_autocast_enabled,
+            dtype=cpu_autocast_dtype,
+            cache_enabled=autocast_cached,
+        )
+    else:
+        gpu_autocast_enabled = torch.is_autocast_enabled()
+        gpu_autocast_dtype = torch.get_autocast_gpu_dtype()
+        gpu_autocast_ctx = torch.cuda.amp.autocast(
+            gpu_autocast_enabled, gpu_autocast_dtype, autocast_cached
+        )
+        cpu_autocast_enabled = torch.is_autocast_cpu_enabled()
+        cpu_autocast_dtype = torch.get_autocast_cpu_dtype()
+        cpu_autocast_ctx = torch.cpu.amp.autocast(
+            cpu_autocast_enabled, cpu_autocast_dtype, autocast_cached
+        )

    return gpu_autocast_ctx, cpu_autocast_ctx
@@ -561,7 +585,9 @@ def has_te_modules(network):
    """
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
-    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
+    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
+    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
+    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_classes_list = [
@@ -893,8 +919,10 @@ def _all_gather_fp8(
    # Note: We cannot directly all-gather the transposed FP8 tensor,
    # so temporarily modify quantizer to avoid creating FP8 transpose.
    if not isinstance(inp, Float8TensorBase):
-        if quantizer is None:
-            raise ValueError("Input tensor is not FP8 and no quantizer was provided")
+        assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))
+        # we cannot directly gather the transposed fp8 tensor
+        # so we need to disable columnwise usage for the quantizer
+        # and then set it back to the original value after quantizing
        init_rowwise_usage = quantizer.rowwise_usage
        init_columnwise_usage = quantizer.columnwise_usage
        quantizer.set_usage(rowwise=True, columnwise=False)
@@ -938,7 +966,7 @@ def _all_gather_fp8(
    # Make sure FP8 transpose is populated if needed
    needs_transpose = (
-        quantizer is not None and quantizer.columnwise_usage and not non_tn_fp8_gemm_supported()
+        quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported()
    )
    if needs_transpose:
        if handle is not None:
@@ -1037,11 +1065,11 @@ def _all_gather_mxfp8(
        dtype = inp.dtype
    elif isinstance(inp, MXFP8TensorBase):
        if inp._rowwise_data is not None:
-            in_shape = inp._rowwise_data.device.size()
+            in_shape = inp._rowwise_data.size()
            device = inp._rowwise_data.device
            dtype = inp._rowwise_data.dtype
        elif inp._columnwise_data is not None:
-            in_shape = inp._columnwise_data.device.size()
+            in_shape = inp._columnwise_data.size()
            device = inp._columnwise_data.device
            dtype = inp._columnwise_data.dtype
        else:
@@ -1474,7 +1502,9 @@ def _is_te_module(module):
    """
    from .module import LayerNorm, RMSNorm
    from .module.base import TransformerEngineBaseModule
-    from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
+    from .attention.dot_product_attention.dot_product_attention import DotProductAttention
+    from .attention.dot_product_attention.backends import UnfusedDotProductAttention
+    from .attention.multi_head_attention import MultiheadAttention
    from .transformer import TransformerLayer

    te_classes_list = [
...
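
As a side note on the `_all_gather_fp8` comment above ("we cannot directly gather the transposed fp8 tensor"): a small, self-contained PyTorch sketch of why concatenating per-rank transposes along the gather dimension does not produce the transpose of the gathered tensor (toy integer data standing in for FP8 payloads; illustrative only):

import torch

tp = 2
shards = [torch.arange(6).view(2, 3) + 6 * r for r in range(tp)]  # per-rank (rows, cols) shards

full = torch.cat(shards, dim=0)                     # what an all-gather along dim 0 produces
naive = torch.cat([s.t() for s in shards], dim=0)   # "gathering" the per-rank transposes
assert not torch.equal(naive, full.t())             # wrong layout: (6, 2) vs (3, 4)
assert torch.equal(torch.cat([s.t() for s in shards], dim=1), full.t())
# Hence the code gathers row-wise data only and rebuilds the transpose afterwards.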
@@ -520,8 +520,8 @@ class FP8GlobalStateManager:
            return

        # Store updated amaxes and scales from phase 1 post forward.
-        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history
-        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale
+        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
+        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()

        # Retrieve stashed amaxes and scales from phase 1 pre forward.
        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
...
@@ -536,7 +536,9 @@ def _make_graphed_callables(
                # Only Set the FP8 meta for the modules included by forward
                continue
            fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
-            from transformer_engine.pytorch.attention import DotProductAttention
+            from transformer_engine.pytorch.attention.dot_product_attention import (
+                DotProductAttention,
+            )
            if (
                isinstance(m, DotProductAttention)
...
@@ -8,6 +8,9 @@ from functools import wraps
from typing import Callable, Optional, Tuple

import torch
+from . import torch_version
+from .utils import gpu_autocast_ctx
from torch.utils.cpp_extension import IS_HIP_EXTENSION

# pylint: disable=unnecessary-lambda-assignment
@@ -32,13 +35,13 @@ def lazy_compile(func):
jit_fuser = lambda func: func
-if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 0, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = lazy_compile

# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
-if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
+if torch_version() >= (2, 2, 0) and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    dropout_fuser = lazy_compile
@@ -51,11 +54,9 @@ def set_jit_fusion_options() -> None:
    if not IS_HIP_EXTENSION:
        """Set PyTorch JIT layer fusion options."""
        # flags required to enable jit fusion kernels
-        TORCH_MAJOR = int(torch.__version__.split(".")[0])
-        TORCH_MINOR = int(torch.__version__.split(".")[1])
-        if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
+        if torch_version() >= (2, 2, 0):
            pass
-        elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
+        elif torch_version() >= (1, 10, 0):
            # nvfuser
            torch._C._jit_set_profiling_executor(True)
            torch._C._jit_set_profiling_mode(True)
@@ -124,7 +125,7 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor:
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Disable native AMP for bias_gelu_fused_"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bias_gelu_fused_(inp, bias)
        return gelu_fused_(inp)
@@ -134,7 +135,7 @@ def bgrad_dgelu_fused(
    grad_output: torch.Tensor, inp: torch.Tensor, bias: torch.Tensor
) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
    """Disable native AMP for `bgrad_dgelu_fused_`"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        if bias is not None and bias.numel() != 0:
            return bgrad_dgelu_fused_(grad_output, inp, bias)
        return None, dgelu_fused_(grad_output, inp)
@@ -175,7 +176,7 @@ def bias_dropout_add_fused_train(
) -> torch.Tensor:
    """Disable native AMP and enable grad for BDA"""
    with torch.enable_grad():
-        with torch.cuda.amp.autocast(enabled=False):
+        with gpu_autocast_ctx(enabled=False):
            return bias_dropout_add_fused_train_(x, bias, residual, prob)
@@ -191,7 +192,7 @@ def bias_dropout_add_fused_inference(
    x: torch.Tensor, bias: torch.Tensor, residual: torch.Tensor, prob: float
) -> torch.Tensor:
    """Disable native AMP for BDA"""
-    with torch.cuda.amp.autocast(enabled=False):
+    with gpu_autocast_ctx(enabled=False):
        return bias_dropout_add_fused_inference_(x, bias, residual, prob)
...
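
The jit.py hunks above swap the deprecated torch.cuda.amp.autocast wrappers for a gpu_autocast_ctx helper imported from .utils. A minimal sketch of what such a version-gated helper can look like (an assumption for illustration, not the library's actual implementation; the name gpu_autocast_ctx_sketch and the version parsing are made up):

import torch

def gpu_autocast_ctx_sketch(enabled: bool = True):
    # torch.cuda.amp.autocast is deprecated from PyTorch 2.4 onwards in favor of
    # the device-qualified torch.amp.autocast("cuda", ...).
    version = tuple(int(v) for v in torch.__version__.split(".")[:2])
    if version >= (2, 4):
        return torch.amp.autocast("cuda", enabled=enabled)
    return torch.cuda.amp.autocast(enabled=enabled)

# Usage mirrors the fused wrappers above: run the fused op with native AMP disabled.
with gpu_autocast_ctx_sketch(enabled=False):
    pass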
@@ -6,8 +6,6 @@
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass
-from functools import reduce
-from operator import mul as multiply_op
import queue

import torch
@@ -15,7 +13,6 @@ import torch
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
-from ..tensor.float8_tensor import Float8Tensor
import warnings
try:
    from lightop import rmsnorm_forward,rmsnorm_backward
@@ -24,7 +21,6 @@ except ImportError:
    enable_lightop = False
    warnings.warn("Failed to import lightop module. Falling back to alternative implementation.", UserWarning)

def _get_normalization_func(normalization: str, forward: bool):
    fwd_normalization_funcs = {
        "LayerNorm": tex.layernorm_fwd,
@@ -40,39 +36,6 @@ def _get_normalization_func(normalization: str, forward: bool):
    return bwd_normalization_funcs[normalization]
def _fix_gathered_fp8_transpose(fp8_tensor: Float8Tensor, tp_size: int) -> Float8Tensor:
"""Reorder FP8 transposes after Userbuffers gather.
The all-gather is performed in-place in the Float8Tensor's
row-wise data, and afterwards we need to do a transpose to get the
correct ordering. This misuses data fields in Float8Tensor and
should be considered an evil hack. It would be best to move
transpose logic into CommOverlap::get_buffer.
Responsibility for fixing: adener, tmoon
"""
assert isinstance(fp8_tensor, Float8Tensor), "Tensor is not a Float8Tensor"
assert tp_size > 1, "The tensor transpose cannot be interleaved when TP size is 1"
assert fp8_tensor._data is not None, "The tensor does not hold any rowwise data"
assert (
fp8_tensor._data.shape[0] % tp_size == 0
), "Leading dimension of data is not divisble by TP size"
data = fp8_tensor._data
batched_size = reduce(multiply_op, data.shape[1:])
interleaved_shape = [tp_size, data.shape[0] // tp_size, batched_size]
transposed_shape = [data.shape[0] // tp_size, batched_size * tp_size]
fp8_tensor._transpose = (
data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
)
fp8_tensor._transpose_invalid = False
fp8_tensor._data = None
return fp8_tensor
def apply_normalization(
    inputmat: torch.Tensor,
    ln_out: torch.Tensor,
...
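
The removed `_fix_gathered_fp8_transpose` helper's docstring describes reordering row-wise data that was all-gathered in place across tensor-parallel ranks. A toy PyTorch sketch of that reshape/transpose (hypothetical small shapes, for illustration only):

import torch

tp_size, local_rows, cols = 2, 3, 4   # hypothetical sizes
data = torch.arange(tp_size * local_rows * cols).view(tp_size * local_rows, cols)

# [tp_size * local_rows, cols] -> [tp_size, local_rows, cols] -> swap the first two dims
# -> flatten, mirroring data.view(interleaved_shape).transpose(0, 1).contiguous().view(transposed_shape)
reordered = (
    data.view(tp_size, local_rows, cols)
    .transpose(0, 1)
    .contiguous()
    .view(local_rows, tp_size * cols)
)
assert reordered.shape == (local_rows, tp_size * cols)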
@@ -4,6 +4,7 @@
"""Base modules and utilities for TransformerEngine PyTorch API"""
import io
+import math
import os
import pickle
import warnings
@@ -35,10 +36,13 @@ from ..distributed import (
    _fsdp_gather_tensors,
)
from ..constants import dist_group_type
-from ..tensor import QuantizedTensor, Quantizer
+from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer
+from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
+from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
+from ..utils import torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import Recipe
from ...debug.pytorch.debug_state import TEDebugState
@@ -451,6 +455,142 @@ def destroy_ub():
    layers_atomic_ring_exchange = []

def fill_userbuffers_buffer_for_all_gather(
comm,
local_tensor: torch.Tensor,
quantizer: Optional[Quantizer],
process_group,
) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]:
"""Fill local shard of Userbuffers buffer with data for all-gather
Returns the full tensor and the local shard, both using the
Userbuffers buffer as their underlying data. These tensors should
be used carefully (e.g. only immediately before and after a
Userbuffers operation) since the underlying data may be
overwritten by other Userbuffers operations.
May perform blocking communication if needed for the gathered
tensor's metadata, e.g. scaling factors.
"""
# Tensor dimensions
local_shape = local_tensor.size()
if not local_shape:
raise ValueError(f"Invalid local tensor (shape={tuple(local_shape)})")
process_group_size = torch.distributed.get_world_size(process_group)
global_shape = list(local_shape)
global_shape[0] *= process_group_size
# Unquantized data
if quantizer is None:
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor = local_tensor.dequantize()
if comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather unquantized tensor, "
"but Userbuffers is initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor, local_chunk=True)
global_tensor = comm.get_buffer(shape=global_shape)
return global_tensor, local_tensor
# FP8 data
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
if not isinstance(local_tensor, Float8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor.dequantize()
quantizer.set_usage(rowwise=True, columnwise=False)
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather FP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
comm.copy_into_buffer(local_tensor._data, local_chunk=True)
global_tensor_data = comm.get_buffer(shape=global_shape)
global_tensor = Float8TensorBase(
data=global_tensor_data,
fp8_scale_inv=local_tensor._scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# MXFP8 data
if isinstance(quantizer, MXFP8Quantizer):
# Cast to MXFP8 if needed
if not isinstance(local_tensor, MXFP8TensorBase):
if isinstance(local_tensor, QuantizedTensorBase):
local_tensor.dequantize()
local_tensor = quantizer(local_tensor)
if not comm.is_fp8_ubuf():
raise RuntimeError(
"Attempting to all-gather MXFP8 tensor, "
"but Userbuffers is not initialized with FP8 buffers"
)
# Check which MXFP8 buffer to communicate
if quantizer.rowwise_usage == quantizer.columnwise_usage:
raise ValueError(
"Userbuffers can only communicate one MXFP8 buffer at a time, "
f"but quantizer has rowwise_usage={quantizer.rowwise_usage}, "
f"columnwise_usage={quantizer.columnwise_usage}"
)
with_rowwise_data = quantizer.rowwise_usage
# Copy MXFP8 data to local chunk of Userbuffers buffer
local_data = (
local_tensor._rowwise_data if with_rowwise_data else local_tensor._columnwise_data
)
comm.copy_into_buffer(local_data, local_chunk=True)
# Gather scaling-inverses
if math.prod(local_shape[:-1]) % 128 != 0:
raise ValueError(
"Userbuffers requires MXFP8 tensor dims that are divisible by 128, "
f"but got MXFP8 tensor with shape={tuple(local_shape)}"
)
local_scale_inv = (
local_tensor._rowwise_scale_inv
if with_rowwise_data
else local_tensor._columnwise_scale_inv
)
local_scale_inv_size = list(local_scale_inv.size())
global_scale_inv = torch.empty(
[process_group_size * local_scale_inv_size[0]] + local_scale_inv_size[1:],
dtype=local_scale_inv.dtype,
device=local_scale_inv.device,
)
torch.distributed.all_gather_into_tensor(
global_scale_inv,
local_scale_inv,
group=process_group,
)
# Construct MXFP8 tensor with Userbuffers buffer
rowwise_data, rowwise_scale_inv = None, None
columnwise_data, columnwise_scale_inv = None, None
global_data = comm.get_buffer(shape=global_shape)
if with_rowwise_data:
rowwise_data, rowwise_scale_inv = global_data, global_scale_inv
else:
columnwise_data, columnwise_scale_inv = global_data, global_scale_inv
global_tensor = MXFP8TensorBase(
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
fp8_dtype=local_tensor._fp8_dtype,
quantizer=quantizer,
)
return global_tensor, local_tensor
# Unsupported data format
raise ValueError(f"Unsupported quantizer for Userbuffers ({quantizer})")
class TransformerEngineBaseModule(torch.nn.Module, ABC):
    """Base TE module."""
@@ -625,7 +765,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        reset("scaling_fwd")
        reset("scaling_bwd")

-    def get_extra_state(self) -> torch.Tensor:
+    def get_extra_state(self) -> Optional[torch.Tensor]:
        """Save before checkpointing."""

        # This implementation is working around a few issues:
@@ -659,25 +799,26 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Store FP8 state if needed
        state = None
        fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
-        if fp8_checkpoint:
+        if not fp8_checkpoint:
+            return None

        # Copy tensors to CPU and store
        state = {}
        state["recipe"] = self.fp8_meta["recipe"]
        if state["recipe"].delayed():
            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)

        # Store other pickelable values
        extra = {}
        for k, v in self.fp8_meta.items():
            if k != "buffer_index_and_autocast_key" and isinstance(
                v, (bool, int, float, str, tuple, list)
            ):
                extra[k] = v
        state["extra_fp8_variables"] = extra

        # Serialize state into byte tensor
        torch.cuda.synchronize()
@@ -685,7 +826,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
        return state_serialized

-    def set_extra_state(self, state: torch.Tensor) -> None:
+    def set_extra_state(self, state: Optional[torch.Tensor]) -> None:
        """Load previous state."""
        if state is None:
            return
@@ -734,7 +875,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        """Get activation data type for AMP."""
        # Native AMP (`torch.autocast`) gets highest priority
        if torch.is_autocast_enabled():
-            self.activation_dtype = torch.get_autocast_gpu_dtype()
+            self.activation_dtype = torch_get_autocast_gpu_dtype()
            return

        # All checks after this have already been performed once, thus skip
@@ -898,11 +1039,15 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        # Non-FP8 case: bgrad is fused with wgrad for this case.
        if not ctx.fp8 and not ctx.debug:
            if gather_grad_output:
-                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:
+                if not ctx.ub_overlap_ag or ctx.ub_obj_gradout is None:  # Perform NCCL all-gather
                    grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group)
-                else:
-                    ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-                    grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+                else:  # Initialize Userbuffers all-gather
+                    grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                        ctx.ub_obj_gradout,
+                        grad_output,
+                        None,
+                        ctx.tp_group,
+                    )
            return grad_output, None

        # FP8 with all-gather: unfused bgrad, fused cast + transpose
@@ -925,8 +1070,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
                grad_output = quantizer(grad_output)

            # Copy into communication buffer, and replace original gradient with it
-            ctx.ub_obj_gradout.copy_into_buffer(grad_output, quantizer, local_chunk=True)
-            grad_output = ctx.ub_obj_gradout.get_buffer(quantizer)
+            grad_output, _ = fill_userbuffers_buffer_for_all_gather(
+                ctx.ub_obj_gradout,
+                grad_output,
+                quantizer,
+                ctx.tp_group,
+            )
        else:
            grad_output, _ = gather_along_first_dim(
                grad_output,
@@ -1140,7 +1289,16 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            raise ValueError(
                "tensor and quantizer kwargs must be provided to construct FP8 workspace"
            )
+        if cache_name is not None:
+            # Ensure the tensor in the cache is an instance of torch.Tensor,
+            # as it persists beyond a single forward pass.
+            # Setting internal=True would cause the data to be removed in prepare_for_saving(...).
+            quantizer_internal = quantizer.internal
+            quantizer.internal = False
        out = quantizer.quantize(tensor, dtype=workspace_dtype)
+        if cache_name is not None:
+            quantizer.internal = quantizer_internal

        # Update cache
        if cache_name is not None:
@@ -1188,7 +1346,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
        if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute():
            return
        with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"):
-            (wgrad, grad_bias_, _, _), _ = self.wgrad_store.pop()
+            (wgrad, bgrad), _ = self.wgrad_store.pop()
            if not self.fuse_wgrad_accumulation:
                unfused_weights = [getattr(self, name) for name in self.weight_names]
                weight_tensor = noop_cat(unfused_weights)
@@ -1197,9 +1355,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
            if self.use_bias:
                bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
                if bias_tensor.grad is None:
-                    bias_tensor.grad = grad_bias_.to(bias_tensor.dtype)
-            del grad_bias_
-            del wgrad
+                    bias_tensor.grad = bgrad.to(bias_tensor.dtype)

    def _validate_name(self):
        """
...
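
The get_extra_state/set_extra_state changes above keep the pattern of serializing the FP8 metadata dict into a uint8 tensor so that it travels with the module's checkpoint. A schematic round trip of that idea (the dict contents and the use of pickle here are stand-ins for illustration; the real methods also move scales and amax histories to CPU first):

import pickle
import torch

state = {"recipe": "delayed_scaling", "extra_fp8_variables": {"fp8_checkpoint": True}}  # stand-in dict

state_bytes = bytearray(pickle.dumps(state))                          # serialize the dict
state_serialized = torch.frombuffer(state_bytes, dtype=torch.uint8)   # store as a byte tensor

restored = pickle.loads(bytes(state_serialized.numpy()))              # the set_extra_state direction
assert restored == state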
@@ -4,6 +4,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
+import warnings
import functools

import torch
@@ -43,7 +44,7 @@ from ..graph import is_graph_capturing
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
-    QuantizedTensor,
+    QuantizedTensorBase,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
@@ -182,11 +183,11 @@ class _GroupedLinear(torch.autograd.Function):
        # TODO: update after #1638 is merged. # pylint: disable=fixme
        if weight_requires_grad:
            for inputmat in inputmats:
-                if isinstance(inputmat, QuantizedTensor):
+                if isinstance(inputmat, QuantizedTensorBase):
                    inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
        if inp.requires_grad:
            for weight in weights_fp8:
-                if isinstance(weight, QuantizedTensor):
+                if isinstance(weight, QuantizedTensorBase):
                    weight.update_usage(columnwise_usage=True)

        tensors_to_save, tensor_objects = prepare_for_saving(
@@ -299,7 +300,7 @@ class _GroupedLinear(torch.autograd.Function):
            )

        for weight, quantizer in zip(weights, ctx.weight_quantizers):
-            if quantizer is not None and isinstance(weight, QuantizedTensor):
+            if quantizer is not None and isinstance(weight, QuantizedTensorBase):
                weight.update_usage(
                    rowwise_usage=quantizer.rowwise_usage,
                    columnwise_usage=quantizer.columnwise_usage,
@@ -663,7 +664,7 @@ class GroupedLinear(TransformerEngineBaseModule):
            produced)
        """
        assert not isinstance(
-            inp, QuantizedTensor
+            inp, QuantizedTensorBase
        ), "GroupedLinear doesn't support input tensor in FP8."
        assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
@@ -675,9 +676,14 @@ class GroupedLinear(TransformerEngineBaseModule):
        weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
        bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

-        if not self.fp8:
+        if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors):
+            warnings.warn(
+                "You are using quantized weights without quantized compute. "
+                "Please make sure this is intentional."
+            )
            weight_tensors = [
-                w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
+                w.dequantize() if isinstance(w, QuantizedTensorBase) else w
+                for w in weight_tensors
            ]

        input_quantizers, weight_quantizers, output_quantizers = (
...
@@ -94,6 +94,9 @@ class LayerNorm(_LayerNormOp):
            )
        kwargs["dtype"] = params_dtype

+        # Flag for sequence parallelism (custom Megatron-LM integration)
+        self.sequence_parallel: Optional[bool] = sequence_parallel

        # Initialize layer norm operation
        super().__init__(
            normalized_shape,
@@ -102,12 +105,6 @@ class LayerNorm(_LayerNormOp):
            **kwargs,
        )

-        # Flag for sequence parallelism (custom Megatron-LM integration)
-        self.sequence_parallel: Optional[bool] = sequence_parallel
-        if sequence_parallel is not None:
-            self.weight.sequence_parallel = sequence_parallel
-            self.bias.sequence_parallel = sequence_parallel

    def reset_layer_norm_parameters(self) -> None:
        """Init LN params"""
        warnings.warn(
@@ -136,7 +133,7 @@ class LayerNorm(_LayerNormOp):
        super().reset_parameters()

        # Set flag for sequence parallelism (custom Megatron-LM integration)
-        if getattr(self, "sequence_parallel", None) is not None:
+        if self.sequence_parallel is not None:
            self.weight.sequence_parallel = self.sequence_parallel
            self.bias.sequence_parallel = self.sequence_parallel
...
@@ -9,7 +9,6 @@ from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
-import functools

import torch
from torch.nn import init
@@ -18,6 +17,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
+    fill_userbuffers_buffer_for_all_gather,
    get_workspace,
    get_ub,
    TransformerEngineBaseModule,
@@ -53,9 +53,10 @@ from ..distributed import (
from ..constants import GemmParallelModes, dist_group_type
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
-from ._common import apply_normalization, noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
+from ._common import apply_normalization, noop_cat, WeightGradStore
from ..tensor.quantized_tensor import (
    QuantizedTensor,
+    QuantizedTensorBase,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
@@ -144,8 +145,10 @@ class _LayerNormLinear(torch.autograd.Function):
        # Make sure input dimensions are compatible
        out_features, in_features = weight.shape
        inp_shape = inp.shape
+        inp_requires_grad = inp.requires_grad
        assert inp_shape[-1] == in_features, "GEMM not possible"
-        inputmat = inp.view((-1, in_features))
+        inp = inp.view((-1, in_features))
+        inputmat = inp
        if fp8:
            assert_dim_for_fp8_exec(inputmat, weight)
@@ -158,42 +161,43 @@ class _LayerNormLinear(torch.autograd.Function):
            nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

        tp_world_size = get_distributed_world_size(tp_group)
-        ub_overlap_ag_fprop = (
-            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
-        )
        weight_requires_grad = weight.requires_grad
        backward_needs_input = is_grad_enabled and weight_requires_grad
        with_input_all_gather = parallel_mode == "column" and sequence_parallel

-        # Check if Userbuffers is supported
-        if fp8:
-            if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not (
-                FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
-            ):
-                raise NotImplementedError(
-                    "Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
-                    " current scaling"
-                )
+        # Configure Userbuffers communication (comm+GEMM overlap)
+        ub_obj = None
+        ub_type = None
+        ub_overlap_ag_fprop = (
+            ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output
+        )
+        if ub_overlap_rs_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.RS
+        elif ub_overlap_ag_fprop:
+            ub_obj = get_ub(ub_name + "_fprop")
+            ub_type = tex.CommOverlapType.AG

        # Configure quantizer for norm output
        if fp8:
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
-            columnwise_usage = backward_needs_input
-            if (
-                columnwise_usage
-                and with_input_all_gather
-                and not isinstance(input_quantizer, MXFP8Quantizer)
-            ):
-                columnwise_usage = False
-            input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+            input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
+            if with_input_all_gather and isinstance(
+                input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+            ):
+                # All-gather is not supported with FP8 column-wise data
+                input_quantizer.set_usage(columnwise=False)

-        # Avoid quantized norm kernel if norm output will be returned
-        # or if a gather of ln_out must be in high precision.
+        # Do TP communication in high precision if quantized format
+        # does not support communication
        force_hp_blockwise_ln_out_gather = (
            fp8 and with_input_all_gather and isinstance(input_quantizer, Float8BlockQuantizer)
-        )  # Perform TP communication in high precision.
+        )
+        # Avoid quantized norm kernel if norm output will be returned
+        # or if a gather of ln_out must be in high precision.
        with_quantized_norm = (
            fp8
            and not return_layernorm_output
@@ -215,16 +219,19 @@ class _LayerNormLinear(torch.autograd.Function):
            fwd_ln_sm_margin,
            zero_centered_gamma,
        )
+        nvtx_range_pop(f"{nvtx_label}.norm")

+        # Store unquantized layer norm output if we need to return it
        ln_out_return = None
        if return_layernorm_output or return_layernorm_output_gathered:
            ln_out_return = ln_out
-        nvtx_range_pop(f"{nvtx_label}.norm")

-        # Prepare GEMM input
+        # ------------------------------------------------------
+        # Prepare GEMM input tensor
        # Note: Cast to expected dtype and perform tensor-parallel communication
+        # ------------------------------------------------------
        nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
        ln_out_total = None
-        ub_obj_fprop = None
        if with_input_all_gather:
            if return_layernorm_output_gathered:
                # Perform all-gather in high precision if gathered
...@@ -232,47 +239,53 @@ class _LayerNormLinear(torch.autograd.Function):

        ln_out_total, _ = gather_along_first_dim(ln_out, tp_group)
        ln_out_return = ln_out_total
        if fp8 or debug:
            if not force_hp_blockwise_ln_out_gather:
                ln_out = input_quantizer(ln_out)
            input_quantizer.set_usage(rowwise=True, columnwise=False)
            ln_out_total = input_quantizer(ln_out_total)
    else:
        quantizer = None
        if fp8 or debug:
            quantizer = input_quantizer
            if not with_quantized_norm and not force_hp_blockwise_ln_out_gather:
                ln_out = quantizer(ln_out)
            quantizer.set_usage(rowwise=True, columnwise=False)
        if ub_overlap_ag_fprop:
            # Initialize Userbuffers all-gather
            ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                ub_obj,
                ln_out,
                quantizer,
                tp_group,
            )
        else:
            # Perform NCCL all-gather
            ln_out_total, _ = gather_along_first_dim(
                ln_out,
                tp_group,
                quantizer=quantizer,
            )
else:
    if (fp8 or debug) and not with_quantized_norm:
        ln_out = input_quantizer(ln_out)
    ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
# ------------------------------------------------------
# GEMM input tensor is ready...
# ------------------------------------------------------

# ------------------------------------------------------
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight
quantized_weight = False
if fp8 or debug:
    quantized_weight = not isinstance(weight, QuantizedTensorBase)

    # Configure quantizer
    if weight_quantizer is not None:
        weight_quantizer.set_usage(rowwise=True, columnwise=True)

    # Get quantized weight
    update_workspace = is_first_microbatch is None or is_first_microbatch
    weightmat = module.get_weight_workspace(
        tensor=weight,
        quantizer=weight_quantizer,
...@@ -282,17 +295,21 @@ class _LayerNormLinear(torch.autograd.Function):

        fsdp_group=fsdp_group,
        workspace_dtype=activation_dtype,
    )
    weightmat.update_usage(rowwise_usage=True)
else:
    weightmat = cast_if_needed(weightmat, activation_dtype)  # Cast for AMP
# ------------------------------------------------------
# Weight tensor is ready for GEMM...
# ------------------------------------------------------

# Cast bias to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
    # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
    bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias

# Calibrate quantizers if needed
if not fp8 and fp8_calibration:
    if input_quantizer is not None:
...@@ -300,47 +317,80 @@ class _LayerNormLinear(torch.autograd.Function):

    if weight_quantizer is not None:
        weight_quantizer.calibrate(weight)

# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_FPROP
if fp8:
    recipe = FP8GlobalStateManager.get_fp8_recipe()
    if hasattr(recipe, "fp8_gemm_fprop"):
        use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

# Configure output quantizer
if output_quantizer is not None:
    output_quantizer.set_usage(rowwise=True, columnwise=False)

# Output buffer for Userbuffers reduce-scatter
reduce_scatter_out = None
if ub_overlap_rs_fprop:
    out_shape = list(inp_shape)
    out_shape[0] //= tp_world_size
    out_shape[-1] = out_features
    reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)

# ------------------------------------------------------
# Forward GEMM
# Note: y = x * w^T
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
    weightmat,
    ln_out_total,
    get_workspace(),
    quantization_params=output_quantizer,
    out_dtype=activation_dtype,
    bias=bias,
    use_split_accumulator=use_split_accumulator,
    ub=ub_obj,
    ub_type=ub_type,
    extra_output=reduce_scatter_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")
# ------------------------------------------------------
# Finished forward GEMM...
# ------------------------------------------------------

# Deallocate GEMM input tensor if no longer needed
if not weight.requires_grad and not return_layernorm_output:
    ln_out = ln_out_total = None
    clear_tensor_data(ln_out, ln_out_total)

# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
# ------------------------------------------------------
out = None
if ub_overlap_rs_fprop:
    out = reduce_scatter_out
elif parallel_mode == "row" and tp_size > 1:
    nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
    out = gemm_out
    if sequence_parallel:
        out, _ = reduce_scatter_along_first_dim(out, tp_group)
    elif tensor_parallel:
        if symmetric_ar_type is not None:
            out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
        else:
            out, _ = allreduce(out, tp_group)
    nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
else:
    out = gemm_out
out = out.view(-1, *inp_shape[1:-1], out_features)
# ------------------------------------------------------
# Output tensor is ready to return...
# ------------------------------------------------------

# ------------------------------------------------------
# Cache state for backward pass
# ------------------------------------------------------
if is_grad_enabled:
    ctx.weight_quantizer = weight_quantizer
...@@ -351,19 +401,15 @@ class _LayerNormLinear(torch.autograd.Function):

    # Input with column-wise usage is needed for wgrad GEMM.
    if backward_needs_input:
        if isinstance(ln_out, QuantizedTensorBase):
            # For sequence parallel in vanilla FP8, rowwise data is
            # to gather the input. For MXFP8, columnwise only data
            # can be allgathered.
            if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                ln_out.update_usage(rowwise_usage=False)

    # Weight with column-wise usage is needed for dgrad GEMM.
    if isinstance(weightmat, QuantizedTensorBase):
        weightmat.update_usage(columnwise_usage=True)

    if cpu_offloading:
...@@ -406,7 +452,7 @@ class _LayerNormLinear(torch.autograd.Function):

    )
    ctx.save_for_backward(*tensors_to_save)
    ctx.tensor_objects = tensor_objects

    ctx.requires_dgrad = inp_requires_grad
    ctx.requires_wgrad = weight.requires_grad
    ctx.quantized_weight = quantized_weight
    if fuse_wgrad_accumulation and weight.requires_grad:
...@@ -439,7 +485,7 @@ class _LayerNormLinear(torch.autograd.Function):

    ctx.ub_bulk_wgrad = ub_bulk_wgrad
    ctx.ub_bulk_dgrad = ub_bulk_dgrad
    ctx.ub_name = ub_name
    ctx.requires_dgrad = inp_requires_grad
    ctx.normalization = normalization
    ctx.reduce_and_update_bwd_fp8_tensors = False
    if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
...@@ -450,29 +496,16 @@ class _LayerNormLinear(torch.autograd.Function):

    ctx.wgrad_store = wgrad_store
    ctx.debug = debug
# ------------------------------------------------------
# Cached state for backward pass is ready...
# ------------------------------------------------------

if return_layernorm_output:
    if return_layernorm_output_gathered:
        shape = list(inp_shape)
        shape[0] *= tp_size
        return out, ln_out_return.view(shape)
    return out, ln_out_return.view(inp_shape)
return out

@staticmethod
...@@ -487,24 +520,6 @@ class _LayerNormLinear(torch.autograd.Function):

    nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
    saved_tensors = ctx.saved_tensors
    (  # pylint: disable=unbalanced-tuple-unpacking
        inputmat,
...@@ -549,66 +564,50 @@ class _LayerNormLinear(torch.autograd.Function):

    if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
        origin_weight.main_grad = main_grad

    # Configure Userbuffers communication (comm+GEMM overlap)
    ctx.ub_obj_gradout = None
    ub_obj_dgrad = None
    ub_obj_wgrad = None
    ub_type_dgrad = None
    ub_type_wgrad = None
    dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
    if ctx.ub_overlap_ag:
        # Overlap grad_output all-gather with dgrad compute
        ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
        ub_obj_dgrad = ctx.ub_obj_gradout
        ub_type_dgrad = tex.CommOverlapType.AG
    elif ctx.ub_overlap_rs_dgrad:
        # Overlap dgrad reduce-scatter with dgrad compute
        ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
        ub_obj_dgrad = ctx.ub_obj_gradout
        ub_type_dgrad = tex.CommOverlapType.RS
    else:
        if ctx.ub_bulk_dgrad:
            # Overlap inputmat all-gather with dgrad compute
            ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
            ub_obj_dgrad = ctx.ub_obj_gradout
            ub_type_dgrad = tex.CommOverlapType.AG
        if ctx.ub_bulk_wgrad:
            # Overlap dgrad reduce-scatter with wgrad compute
            ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
            ub_type_wgrad = tex.CommOverlapType.RS

    # --------------------------------------------------
    # Prepare grad output tensor
    # Note: Cast to expected dtype and perform tensor-parallel communication
    # --------------------------------------------------

    # Configure quantizer for grad output tensor
    # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
    # requires column-wise usage
    if ctx.grad_output_quantizer is not None:
        quantizer = ctx.grad_output_quantizer
        quantizer.set_usage(rowwise=True, columnwise=True)
        if ctx.ub_overlap_ag:
            # Userbuffers only supports communication for one
            # tensor usage at a time. Configure quantizer with
            # usage for only dgrad GEMM.
            quantizer.set_usage(columnwise=False)

    # Prepare grad output tensor
    # Note: Cast to expected dtype and perform tensor-parallel communication
...@@ -624,12 +623,21 @@ class _LayerNormLinear(torch.autograd.Function):

    )
    nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
    # --------------------------------------------------
    # Grad output tensor is ready for computing grad input...
    # --------------------------------------------------

    # --------------------------------------------------
    # Prepare GEMM input tensor
    # Note: Input tensor is needed for wgrad GEMM.
    # Tensor-parallel communication is overlapped with dgrad
    # GEMM.
    # --------------------------------------------------
    ln_out_total = None
    ln_out_total_work = None
    if ctx.ln_out_needs_gather:
        quantizer = None
        if ctx.input_quantizer is not None and not ctx.force_hp_blockwise_ln_out_gather:
            quantizer = ctx.input_quantizer
            if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
                # If data is in FP8, we compute FP8 transposes manually
...@@ -637,70 +645,92 @@ class _LayerNormLinear(torch.autograd.Function):

            else:
                # wgrad GEMM requires input with column-wise usage
                quantizer.set_usage(rowwise=False, columnwise=True)
        if ctx.ub_bulk_dgrad:
            ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                ub_obj_dgrad,
                ln_out,
                quantizer,
                ctx.tp_group,
            )
        else:
            nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
            ln_out_total, ln_out_total_work = gather_along_first_dim(
                ln_out,
                ctx.tp_group,
                async_op=True,
                quantizer=quantizer,
            )
            nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
    else:
        ln_out_total = ln_out
    # --------------------------------------------------
    # Input tensor is ready for computing grad weight...
    # --------------------------------------------------

    # --------------------------------------------------
    # Compute grad input tensor
    # Note: Gradient w.r.t. GEMM input (i.e. norm output).
    # --------------------------------------------------

    # Make sure required data is available
    if isinstance(grad_output, QuantizedTensorBase):
        grad_output.update_usage(rowwise_usage=True)
    if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase):
        weight.update_usage(columnwise_usage=True)

    # Choose whether to use GEMM kernel with split accumulator
    use_split_accumulator = _2X_ACC_DGRAD
    if ctx.fp8:
        recipe = ctx.fp8_recipe
        if hasattr(recipe, "fp8_gemm_dgrad"):
            use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator

    # Update grad input quantizer
    if ctx.grad_input_quantizer is not None:
        ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

    # Output buffers for Userbuffers reduce-scatter
    gemm_out = None
    reduce_scatter_out = None
    if ctx.ub_overlap_rs_dgrad:
        reduce_scatter_out = torch.empty(
            dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
        )
    elif ctx.ub_bulk_wgrad:
        gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)

    # dgrad GEMM
    # Note: dx = dy * w
    nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
    gemm_out, *_, reduce_scatter_out = general_gemm(
        weight,
        grad_output,
        get_workspace(),
        layout="NN",
        grad=True,
        quantization_params=ctx.grad_input_quantizer,
        out=gemm_out,
        out_dtype=ctx.activation_dtype,
        use_split_accumulator=use_split_accumulator,
        ub=ub_obj_dgrad,
        ub_type=ub_type_dgrad,
        extra_output=reduce_scatter_out,
        bulk_overlap=ctx.ub_bulk_dgrad,
    )
    nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

    # Prepare grad input tensor
    # Note: Perform tensor-parallel communication
    dgrad = None
    dgrad_work = None
    if ctx.ub_overlap_rs_dgrad:
        dgrad = reduce_scatter_out
    elif ctx.ub_bulk_wgrad:
        dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
    elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
        nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
        dgrad = gemm_out
        if ctx.sequence_parallel:
            if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
                dgrad = dgrad + grad_outputs[1].view_as(dgrad)
            dgrad, dgrad_work = reduce_scatter_along_first_dim(
                dgrad,
                ctx.tp_group,
...@@ -709,41 +739,55 @@ class _LayerNormLinear(torch.autograd.Function):

        else:
            dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
        nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
    else:
        dgrad = gemm_out
    # --------------------------------------------------
    # Grad input tensor has been computed...
    # --------------------------------------------------

    # --------------------------------------------------
    # Compute grad weight
    # --------------------------------------------------
    wgrad = None
    if ctx.requires_wgrad:

        # Prepare input tensor
        # Note: Synchronize tensor-parallel communication and
        # make sure required data is available
        if ln_out_total_work is not None:
            ln_out_total_work.wait()
            ln_out_total_work = None
        if ctx.fp8 or ctx.debug:
            if isinstance(ln_out_total, QuantizedTensorBase):
                ln_out_total.update_usage(columnwise_usage=True)
            else:
                ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
                ln_out_total = ctx.input_quantizer(ln_out_total)

        # Prepare grad output tensor
        # Note: Synchronize tensor-parallel communication and
        # make sure required data is available
        if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
            # UB does not support overlapping grad output
            # all-gather with wgrad GEMM. Also, we can't
            # convert row-scaled MXFP8 to column-scaled, so we
            # can't reuse the grad output that was gathered
            # for the dgrad GEMM. We work around with blocking
            # all-gather for column-scaled MXFP8 data.
            ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
            grad_output, _ = gather_along_first_dim(
                grad_outputs[0],
                ctx.tp_group,
                quantizer=ctx.grad_output_quantizer,
            )
        if ctx.fp8 or ctx.debug:
            if isinstance(grad_output, QuantizedTensorBase):
                grad_output.update_usage(columnwise_usage=True)
            else:
                ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
                grad_output = ctx.grad_output_quantizer(grad_output)

        # Figure out whether to use split accumulator
        use_split_accumulator = _2X_ACC_WGRAD
...@@ -752,55 +796,95 @@ class _LayerNormLinear(torch.autograd.Function):

            if hasattr(recipe, "fp8_gemm_wgrad"):
                use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

        # Figure out whether to output wgrad GEMM directly into main grad
        if ctx.is_first_microbatch is not None:
            accumulate_wgrad_into_param_main_grad = (
                ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
            )
        else:
            accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation

        # Output buffer for overlapping FP8 grad input
        # reduce-scatter with wgrad GEMM
        reduce_scatter_out = None
        if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
            reduce_scatter_out = torch.empty(
                dgrad_shape, dtype=ctx.activation_dtype, device=grad_outputs[0].device
            )

        # Arguments to include in wgrad GEMM closure
        wgrad_gemm_kwargs = {
            "workspace": get_workspace(),
            "out_dtype": (
                main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
            ),
            "quantization_params": ctx.grad_weight_quantizer,
            "accumulate": accumulate_wgrad_into_param_main_grad,
            "layout": "NT",
            "out": main_grad if ctx.fuse_wgrad_accumulation else None,
            "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
            "use_split_accumulator": use_split_accumulator,
            "grad": True,
            "ub": ub_obj_wgrad,
            "ub_type": ub_type_wgrad,
            "extra_output": reduce_scatter_out,
            "bulk_overlap": ctx.ub_bulk_wgrad,
        }

        def wgrad_gemm(
            x: torch.Tensor,
            dy: torch.Tensor,
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            """Perform wgrad GEMM: dw = dy^T * x

            May be fused with bgrad computation.

            May be called outside of this function to enable
            some advanced communication/compute overlapping.
            """
            nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
            dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
            nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
            return dw, db

        # Choose whether to call wgrad GEMM now or delay
        if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
            if (
                wgrad_gemm_kwargs["ub"] is not None
                or wgrad_gemm_kwargs["ub_type"] is not None
                or wgrad_gemm_kwargs["extra_output"] is not None
                or wgrad_gemm_kwargs["bulk_overlap"]
            ):
                raise NotImplementedError(
                    "Delayed weight grad computation is not supported "
                    "with Userbuffers (tensor-parallel communication overlapping)"
                )
            ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
        else:

            # Call wgrad GEMM now
            wgrad, grad_bias_ = wgrad_gemm(ln_out_total, grad_output)

            # Update grad bias if needed
            if grad_bias is None:
                grad_bias = grad_bias_
            del grad_bias_

            # Deallocate input tensor if permitted
            if not ctx.return_layernorm_output:
                # TODO (pgadzinski) - deallocate transpose only  # pylint: disable=fixme
                clear_tensor_data(ln_out_total)

        # Update grad input if overlapping reduce-scatter with wgrad GEMM
        if ctx.ub_bulk_wgrad:
            if ub_obj_wgrad.is_fp8_ubuf():
                dgrad = reduce_scatter_out
            else:
                dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

    # --------------------------------------------------
    # Grad weight has been computed...
    # --------------------------------------------------

    # Don't return grad bias if not needed
    if not ctx.use_bias:
...@@ -879,7 +963,7 @@ class _LayerNormLinear(torch.autograd.Function):

    nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

# Scatter fp8 weight buffers
# if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
#     _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

return (
...@@ -1405,6 +1489,10 @@ class LayerNormLinear(TransformerEngineBaseModule):

        "Splitting QuantizedTensor into multiple params is not supported"
    )
else:
    warnings.warn(
        "You are using quantized weights without quantized compute. "
        "Please make sure this is intentional."
    )
    unfused_weights = [w.dequantize() for w in unfused_weights]
weight_tensor = noop_cat(unfused_weights)
...@@ -1511,7 +1599,7 @@ class LayerNormLinear(TransformerEngineBaseModule):

grad_output_quantizer = None
output_quantizer = None
input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
input_quantizer.internal = True
weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
weight_quantizer.internal = True
if fp8_output:
...@@ -1579,3 +1667,12 @@ class LayerNormLinear(TransformerEngineBaseModule):

    self.quantizers["scaling_bwd"][
        tex.FP8BwdTensors.GRAD_OUTPUT1
    ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon

# parallel related
if self.sequence_parallel and self.parallel_mode == "row":
    # customize grad_output_quantizer with amax reduction TP group
    self.quantizers["scaling_bwd"][
        tex.FP8BwdTensors.GRAD_OUTPUT1
    ].with_amax_reduction = True
    self.quantizers["scaling_bwd"][
        tex.FP8BwdTensors.GRAD_OUTPUT1
    ].amax_reduction_group = self.tp_group
...@@ -8,7 +8,6 @@ import warnings

from typing import Callable, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op

import torch
from torch.nn.parameter import Parameter
...@@ -20,6 +19,7 @@ import transformer_engine_torch as tex

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
    fill_userbuffers_buffer_for_all_gather,
    get_workspace,
    _ub_communicators,
    get_ub,
...@@ -43,7 +43,6 @@ from ..utils import (

    assert_dim_for_fp8_exec,
    clear_tensor_data,
    requires_grad,
    needs_quantized_gemm,
)
from ..distributed import (
...@@ -67,10 +66,11 @@ from ..tensor.float8_tensor import (

)
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ._common import apply_normalization, WeightGradStore
from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
from ..tensor.quantized_tensor import (
    QuantizedTensor,
    QuantizedTensorBase,
    Quantizer,
    prepare_for_saving,
    restore_from_saved,
...@@ -201,24 +201,16 @@ class _LayerNormMLP(torch.autograd.Function):

) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
    # pylint: disable=missing-function-docstring

    # Make sure input dimensions are compatible
    in_features, inp_shape = ln_weight.numel(), inp.shape
    assert inp_shape[-1] == in_features, "GEMM not possible"
    inputmat = inp.view((-1, in_features))
    if fp8:
        assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)

    activation_func = _act_func(
        activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
    )[0]

    # Cast for native AMP
    inputmat = cast_if_needed(inputmat, activation_dtype)
...@@ -226,6 +218,38 @@ class _LayerNormMLP(torch.autograd.Function):

    if ln_bias is not None:
        ln_bias = cast_if_needed(ln_bias, activation_dtype)

    tp_world_size = get_distributed_world_size(tp_group)
    backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad
    device = inp.device

    # Configure Userbuffers communication (comm+GEMM overlap)
    ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
    ub_overlap_rs = ub_overlap_rs and is_grad_enabled

    # Choose whether to use GEMM kernel with split accumulator
    use_split_accumulator = _2X_ACC_FPROP
    if fp8:
        recipe = FP8GlobalStateManager.get_fp8_recipe()
        if hasattr(recipe, "fp8_gemm_fprop"):
            use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator

    # Configure quantizer for norm output
    if fp8:
        if fc1_input_quantizer is None:
            raise ValueError("Missing quantizer for FC1 input tensor")
        fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
        if sequence_parallel and isinstance(
            fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
        ):
            # All-gather is not supported with FP8 column-wise data
            fc1_input_quantizer.set_usage(columnwise=False)

    # Do TP communication in high precision if quantized format
    # does not support communication
    force_hp_fc1_input_gather = (
        fp8 and sequence_parallel and isinstance(fc1_input_quantizer, Float8BlockQuantizer)
    )

    # for fp8 DelayedScaling: layernorm output = FP8
    # only output of the linear is returned
    # for return_layernorm_output: layernorm output = High precision, then cast to FP8
...@@ -241,29 +265,6 @@ class _LayerNormMLP(torch.autograd.Function):

    # Kernels not available for norm fusion.
    with_quantized_norm = False

# Apply normalization
ln_out, mu, rsigma = apply_normalization(
    inputmat,
...@@ -297,39 +298,43 @@ class _LayerNormMLP(torch.autograd.Function):

            fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
            ln_out_total = fc1_input_quantizer(ln_out_total)
    else:
        quantizer = None
        if fp8 or debug:
            quantizer = fc1_input_quantizer
            if not with_quantized_norm and not force_hp_fc1_input_gather:
                ln_out = fc1_input_quantizer(ln_out)
            fc1_input_quantizer.set_usage(rowwise=True, columnwise=False)
        if ub_overlap_ag:
            # Copy into Userbuffers buffer
            ub_obj_lnout = get_ub("fc1_fprop")
            ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
                ub_obj_lnout,
                ln_out,
                quantizer,
                tp_group,
            )
        else:
            # All-gather with NCCL
            ln_out_total, _ = gather_along_first_dim(
                ln_out,
                tp_group,
                quantizer=quantizer,
            )
else:
    if (fp8 or debug) and not with_quantized_norm:
        ln_out = fc1_input_quantizer(ln_out)
    ln_out_total = ln_out

# Cast weights to expected dtype
fc1_weight_final = fc1_weight
fc2_weight_final = fc2_weight
if fp8 or debug:
    # If weights are not quantized, we call get_weight_workspace,
    # which handles weight caching etc.
    # FP8 cast to workspace buffer
    update_workspace = is_first_microbatch is None or is_first_microbatch
    fc1_weight_quantizer.set_usage(rowwise=True, columnwise=True)
    fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
    fc1_weight_final = module.get_weight_workspace(
        tensor=fc1_weight,
        quantizer=fc1_weight_quantizer,
...@@ -339,7 +344,6 @@ class _LayerNormMLP(torch.autograd.Function):

        fsdp_group=fsdp_group,
        workspace_dtype=activation_dtype,
    )
    fc2_weight_final = module.get_weight_workspace(
        tensor=fc2_weight,
        quantizer=fc2_weight_quantizer,
...@@ -349,6 +353,8 @@ class _LayerNormMLP(torch.autograd.Function):

        fsdp_group=fsdp_group,
        workspace_dtype=activation_dtype,
    )
    fc1_weight_final.update_usage(rowwise_usage=True)
    fc2_weight_final.update_usage(rowwise_usage=True)
else:
    fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
    fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
...@@ -356,6 +362,7 @@ class _LayerNormMLP(torch.autograd.Function):

# Cast biases to expected dtype
bias_dtype = activation_dtype
if needs_quantized_gemm(ln_out_total) and activation_dtype == torch.float32:
    # cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
    bias_dtype = torch.bfloat16
if fc1_bias is not None:
    fc1_bias = cast_if_needed(fc1_bias, bias_dtype)
...@@ -369,7 +376,9 @@ class _LayerNormMLP(torch.autograd.Function):

    if fc1_weight_quantizer is not None:
        fc1_weight_quantizer.calibrate(fc1_weight)

# ------------------------------------------------------
# FC1 GEMM
# ------------------------------------------------------
# There are 2 fusions possible:
# - gemm_gelu_fusion - default for full precision, optional for fp8 - need to turn on gemm_gelu_fusion,
...@@ -401,11 +410,16 @@ class _LayerNormMLP(torch.autograd.Function):

        fc1_bias if not bias_gelu_fusion else None
    ),  # otherwise bias is added later (fused with gelu)
    gelu=gemm_gelu_fusion,
    use_split_accumulator=use_split_accumulator,
    ub=ub_obj_lnout,
    ub_type=tex.CommOverlapType.AG if ub_overlap_ag else None,
)
# ------------------------------------------------------
# Finished FC1 GEMM...
# ------------------------------------------------------

# Deallocate FC1 GEMM input tensor if no longer needed
if not is_grad_enabled and (ln_out_total is not ln_out_return):
    clear_tensor_data(ln_out_total)
...@@ -439,45 +453,66 @@ class _LayerNormMLP(torch.autograd.Function):

        fc2_input_quantizer.calibrate(act_out)
        fc2_weight_quantizer.calibrate(fc2_weight)

# Configure Userbuffers reduce-scatter if needed
ub_obj_fc2out = None
reduce_scatter_out = None
if ub_overlap_rs:
    ub_obj_fc2out = get_ub("fc2_fprop")
    dim_size = list(act_out.size())
    dim_size[0] //= tp_world_size
    dim_size[-1] = fc2_weight.size(0)
    reduce_scatter_out = torch.empty(dim_size, dtype=activation_dtype, device=device)

# ------------------------------------------------------
# FC2 GEMM
# ------------------------------------------------------
gemm_out, *_, reduce_scatter_out = general_gemm(
    fc2_weight_final,
    act_out,
    get_workspace(),
    out_dtype=activation_dtype,
    bias=fc2_bias,
    quantization_params=fc2_output_quantizer,
    use_split_accumulator=use_split_accumulator,
    ub=ub_obj_fc2out,
    ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
    extra_output=reduce_scatter_out,
)
# ------------------------------------------------------
# Finished FC2 GEMM...
# ------------------------------------------------------

# Deallocate tensors if no longer needed
if not is_grad_enabled:
    clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)

# Prepare output tensor
# Note: Perform tensor-parallel communication if needed
fc2_out = None
if ub_overlap_rs:
    fc2_out = reduce_scatter_out
elif set_parallel_mode and sequence_parallel:
    fc2_out, _ = reduce_scatter_along_first_dim(gemm_out, tp_group)
elif set_parallel_mode and tensor_parallel:
    if symmetric_ar_type is not None:
        fc2_out, _ = symmetric_all_reduce(
            gemm_out, tp_group, all_reduce_type=symmetric_ar_type
        )
    else:
        fc2_out, _ = allreduce(gemm_out, tp_group)
else:
    fc2_out = gemm_out
fc2_out = fc2_out.view(-1, *inp_shape[1:-1], fc2_out.shape[-1])

# Cache state for backward pass
if is_grad_enabled:

    # Weight with column-wise usage is needed for dgrad GEMM.
    if isinstance(fc1_weight_final, QuantizedTensorBase):
        fc1_weight_final.update_usage(columnwise_usage=True)
    if isinstance(fc2_weight_final, QuantizedTensorBase):
        fc2_weight_final.update_usage(columnwise_usage=True)

    if cpu_offloading:
        mark_activation_offload(
            inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out
...@@ -504,8 +539,6 @@ class _LayerNormMLP(torch.autograd.Function):

    if not return_layernorm_output:
        clear_tensor_data(ln_out)
        ln_out = None
if not fc2_weight.requires_grad:
    clear_tensor_data(act_out)
    act_out = None
...@@ -592,28 +625,12 @@ class _LayerNormMLP(torch.autograd.Function):

    ctx.wgrad_store = wgrad_store

if return_layernorm_output:
    if return_layernorm_output_gathered:
        shape = list(inp_shape)
        shape[0] *= tp_size
        return fc2_out, ln_out_return.view(shape)
    return fc2_out, ln_out_return.view(inp_shape)
return fc2_out

@staticmethod
...@@ -622,24 +639,6 @@ class _LayerNormMLP(torch.autograd.Function):

) -> Tuple[Union[torch.Tensor, None], ...]:
    # pylint: disable=missing-function-docstring
    with torch.cuda.nvtx.range("_LayerNormMLP_backward"):
        saved_tensors = ctx.saved_tensors
        (  # pylint: disable=unbalanced-tuple-unpacking
            inputmat,
...@@ -699,6 +698,16 @@ class _LayerNormMLP(torch.autograd.Function):

#     fc2_weight_fp8 if ctx.fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
# )

# Choose whether to use GEMM kernel with split accumulator
dgrad_use_split_accumulator = _2X_ACC_DGRAD
wgrad_use_split_accumulator = _2X_ACC_WGRAD
if ctx.fp8:
    recipe = ctx.fp8_recipe
    if hasattr(recipe, "fp8_gemm_dgrad"):
        dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
    if hasattr(recipe, "fp8_gemm_wgrad"):
        wgrad_use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

# No need to do bulk DGRAD/WGRAD overlap if WGRAD is not required
ctx.ub_bulk_dgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_dgrad
ctx.ub_bulk_wgrad = ctx.fc1_weight_requires_grad and ctx.ub_bulk_wgrad
...@@ -707,20 +716,13 @@ class _LayerNormMLP(torch.autograd.Function):

# Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage
if ctx.fc2_grad_output_quantizer is not None:
    quantizer = ctx.fc2_grad_output_quantizer
    quantizer.set_usage(rowwise=True, columnwise=True)
    if ctx.ub_overlap_ag:
        # Userbuffers only supports communication for one
        # tensor usage at a time. Configure quantizer with
        # usage for only dgrad GEMM.
        quantizer.set_usage(columnwise=False)

# Prepare FC2 grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
...@@ -738,14 +740,10 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -738,14 +740,10 @@ class _LayerNormMLP(torch.autograd.Function):
# Launch tensor-parallel communication for FC1 GEMM input
ln_out_total = None
ln_out_total_work = None
if (
ctx.fc1_weight_requires_grad
and ctx.tensor_parallel
and ctx.sequence_parallel
and not ctx.ub_bulk_dgrad
):
ub_obj_fc1_dgrad = None
if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel:
quantizer = None
if ctx.fp8 or ctx.debug:
if ctx.fp8 or ctx.debug and not ctx.force_hp_fc1_input_gather:
quantizer = ctx.fc1_input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually
@@ -753,13 +751,21 @@ class _LayerNormMLP(torch.autograd.Function):
else:
# wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True)
gather_quantizer = None if ctx.force_hp_fc1_input_gather else quantizer
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=gather_quantizer,
)
if ctx.ub_bulk_dgrad:
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ln_out_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj_fc1_dgrad,
ln_out,
quantizer,
ctx.tp_group,
)
else:
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
else:
ln_out_total = ln_out
@@ -770,6 +776,11 @@ class _LayerNormMLP(torch.autograd.Function):
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# --------------------------------------------------
# FC2 DGRAD
# --------------------------------------------------
# There are 6 possible fusion paths
# 1 high-precision bias_gelu_fusion: gemm, FC1_bias + gelu,
# 2 high-precision fc2_dgrad_gemm_gelu_fusion: gemm + gelu, FC1_bias + quantize
@@ -784,12 +795,15 @@ class _LayerNormMLP(torch.autograd.Function):
and (not ctx.debug)
)
# FC2 DGRAD; Unconditional
if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
ctx.fc2_weight.update_usage(
rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
)
# Make sure required data is available
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(rowwise_usage=True)
if ctx.fc2_weight_quantizer is not None and isinstance(
ctx.fc2_weight, QuantizedTensorBase
):
ctx.fc2_weight.update_usage(columnwise_usage=True)
# Perform GEMM
gemm_output, *_ = general_gemm(
fc2_weight,
grad_output,
@@ -804,52 +818,107 @@ class _LayerNormMLP(torch.autograd.Function):
out_dtype=ctx.activation_dtype,
gelu=fc2_dgrad_gemm_gelu_fusion,
gelu_in=fc1_out if fc2_dgrad_gemm_gelu_fusion else None,
use_split_accumulator=_2X_ACC_DGRAD,
use_split_accumulator=dgrad_use_split_accumulator,
ub=ub_obj_fc2_dgrad,
ub_type=tex.CommOverlapType.AG if ctx.ub_overlap_ag else None,
)
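Editorial note: for orientation, the FC2 dgrad GEMM above computes dX = dY * W with layout "NN". A quick shape sketch with made-up sizes (hidden=1024, ffn_hidden=4096; these numbers are illustrative, not taken from this diff):
import torch
dY = torch.randn(16, 1024)    # grad_output: [tokens, hidden]
W = torch.randn(1024, 4096)   # fc2_weight: [out_features, in_features] = [hidden, ffn_hidden]
dX = dY @ W                   # grad w.r.t. FC2 input, [tokens, ffn_hidden]
assert dX.shape == (16, 4096)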
# Prepare input grad tensor
dact = None
fc2_dgrad = None
if fc2_dgrad_gemm_gelu_fusion:
dact = gemm_output
fc2_dgrad = None
else:
fc2_dgrad = gemm_output
# --------------------------------------------------
# Finished FC2 DGRAD...
# --------------------------------------------------
# --------------------------------------------------
# FC2 WGRAD
# --------------------------------------------------
fc2_wgrad = None
if ctx.fc2_weight_requires_grad:
if isinstance(act_out, QuantizedTensor):
act_out.update_usage(rowwise_usage=True, columnwise_usage=True)
if isinstance(grad_output, QuantizedTensor):
grad_output.update_usage(rowwise_usage=True, columnwise_usage=True)
# Prepare input tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.fp8 or ctx.debug:
if isinstance(act_out, QuantizedTensorBase):
act_out.update_usage(columnwise_usage=True)
else:
ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True)
act_out = ctx.fc2_input_quantizer(act_out)
# Prepare grad output tensor
# Note: Synchronize tensor-parallel communication and
# make sure required data is available
if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer):
# UB does not support overlapping grad output
# all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_outputs[0],
ctx.tp_group,
quantizer=ctx.fc2_grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.fc2_grad_output_quantizer(grad_output)
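Editorial note: the FC2 wgrad GEMM assembled below computes dW = dY^T * X, which is why grad_output needs column-wise (transposed) usage while dgrad only needed row-wise data. Illustrative shapes only, with made-up sizes:
import torch
X = torch.randn(16, 4096)       # act_out (FC2 input): [tokens, ffn_hidden]
dY = torch.randn(16, 1024)      # grad_output: [tokens, hidden]
dW = dY.T @ X                   # matches general_gemm(..., layout="NT")
assert dW.shape == (1024, 4096)  # same shape as fc2_weight [out_features, in_features]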
# Whether to set grad arg in general_gemm
grad_arg = True
if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling():
grad_arg = False
general_gemm_fc2_wgrad = functools.partial(
general_gemm,
out_dtype=(
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
workspace=get_workspace(),
quantization_params=ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
layout="NT",
grad=grad_arg,
bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
accumulate=accumulate_wgrad_into_param_main_grad,
use_split_accumulator=_2X_ACC_WGRAD,
out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
)
# Arguments to include in wgrad GEMM closure
fc2_wgrad_gemm_kwargs = {
"workspace": get_workspace(),
"out_dtype": (
origin_fc2_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype
),
"quantization_params": ctx.fc2_grad_weight_quantizer,  # wgrad in high precision
"accumulate": accumulate_wgrad_into_param_main_grad,
"layout": "NT",
"out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
"bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
"use_split_accumulator": wgrad_use_split_accumulator,
"grad": grad_arg,
}
def fc2_wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform FC2 WGRAD GEMM
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
dw, db, *_ = general_gemm(x, dy, **fc2_wgrad_gemm_kwargs)
return dw, db
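Editorial note: when delay_wgrad_compute() is enabled, the closure above is stored with its inputs and replayed later (for example from backward_dw). The exact WeightGradStore interface is not shown in this hunk, so the following is only a hedged sketch of the replay pattern with hypothetical names and a plain-PyTorch stand-in for the GEMM:
# Hypothetical replay of a deferred wgrad GEMM; not the actual WeightGradStore API.
import torch
def fake_wgrad_gemm(x, dy):            # same call signature as fc2_wgrad_gemm above
    return dy.T @ x, dy.sum(dim=0)     # (dW, db) computed in plain PyTorch
deferred = []                          # stand-in for the store's internal queue
x, dy = torch.randn(16, 4096), torch.randn(16, 1024)
deferred.append(([x, dy], fake_wgrad_gemm))
# ... later, e.g. from a backward_dw() hook:
tensors, gemm_fn = deferred.pop(0)
dw, db = gemm_fn(*tensors)
assert dw.shape == (1024, 4096) and db.shape == (1024,)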
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
ctx.wgrad_store.put([act_out, grad_output], fc2_wgrad_gemm)
fc2_wgrad = None
else:
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
act_out,
grad_output,
)
# Call wgrad GEMM now
fc2_wgrad, fc2_bias_grad_ = fc2_wgrad_gemm(act_out, grad_output)
# Update grad bias if needed
if fc2_bias_grad is None:
if (
ctx.fp8
@@ -858,12 +927,17 @@ class _LayerNormMLP(torch.autograd.Function):
):
# BGRAD not fused with GEMM for float8 blockwise gemm.
fc2_bias_grad_ = act_out.view(-1, act_out.shape[-1]).sum(dim=0)
fc2_bias_grad = fc2_bias_grad_
del fc2_bias_grad_
# Deallocate input tensor if permitted
if ctx.wgrad_store is not None and not ctx.wgrad_store.delay_wgrad_compute():
clear_tensor_data(act_out)
# --------------------------------------------------
# Finished FC2 WGRAD...
# --------------------------------------------------
# bias computation
fc1_bias_grad = None
fuse_gemm_and_bias_fc1_wgrad = False
@@ -927,63 +1001,69 @@ class _LayerNormMLP(torch.autograd.Function):
ub_obj_fc1_dgrad = None
ub_obj_fc1_wgrad = None
ub_type_fc1_dgrad = None
ub_type_fc1_wgrad = None
fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]]
fc1_dgrad_rs_out = None
fc1_dgrad_bulk = None
if ctx.ub_overlap_rs_dgrad:
# Overlap DGRAD+RS
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_type_fc1_dgrad = tex.CommOverlapType.RS
fc1_dgrad_rs_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
else:
if ctx.ub_bulk_dgrad:
# Overlap ln_out all-gather with DGRAD compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ub_obj_fc1_dgrad = get_ub("fc1_dgrad")
ub_type_fc1_dgrad = tex.CommOverlapType.AG
ub_obj_fc1_dgrad.copy_into_buffer(
ln_out, ctx.fc1_input_quantizer, local_chunk=True
)
if ctx.ub_bulk_wgrad:
# Overlap FC1 DGRAD reduce-scatter with WGRAD compute
ub_obj_fc1_wgrad = get_ub("fc1_wgrad")
fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
ub_type_fc1_wgrad = tex.CommOverlapType.RS
# FC1 DGRAD: Unconditional
# --------------------------------------------------
# FC1 DGRAD
# --------------------------------------------------
# Make sure required data is available
if ctx.fc1_weight_quantizer is not None and isinstance(
ctx.fc1_weight_quantizer, QuantizedTensor
ctx.fc1_weight_quantizer, QuantizedTensorBase
):
ctx.fc1_weight.update_usage(
rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
)
ctx.fc1_weight.update_usage(columnwise_usage=True)
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
)
if ctx.ub_bulk_wgrad:
gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False)
# dgrad GEMM
fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
gemm_out, *_, reduce_scatter_out = general_gemm(
fc1_weight,
dact,
get_workspace(),
out=fc1_dgrad_bulk,
out=gemm_out,
out_dtype=ctx.activation_dtype,
quantization_params=ctx.fc1_grad_input_quantizer,
layout="NN",
grad=True,
use_split_accumulator=dgrad_use_split_accumulator,
ub=ub_obj_fc1_dgrad,
ub_type=ub_type_fc1_dgrad,
extra_output=fc1_dgrad_rs_out,
extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
# Overlap dgrad-RS/AR with wgrad # Prepare grad input tensor
# Note: Perform tensor-parallel communication
fc1_dgrad = None
fc1_dgrad_work = None fc1_dgrad_work = None
if ctx.ub_overlap_rs_dgrad: if ctx.ub_overlap_rs_dgrad:
fc1_dgrad = fc1_dgrad_rs_out fc1_dgrad = reduce_scatter_out
elif ctx.ub_bulk_wgrad:
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True)
elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad: elif ctx.set_parallel_mode and not ctx.ub_bulk_wgrad:
fc1_dgrad = gemm_out
if ctx.sequence_parallel: if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad) fc1_dgrad = fc1_dgrad + grad_outputs[1].view_as(fc1_dgrad)
@@ -994,90 +1074,125 @@ class _LayerNormMLP(torch.autograd.Function):
) )
elif ctx.tensor_parallel: elif ctx.tensor_parallel:
fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True) fc1_dgrad, fc1_dgrad_work = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
else:
fc1_dgrad = gemm_out
# --------------------------------------------------
# Finished FC1 DGRAD...
# --------------------------------------------------
# --------------------------------------------------
# FC1 WGRAD # FC1 WGRAD
# --------------------------------------------------
fc1_wgrad = None fc1_wgrad = None
if ctx.fc1_weight_requires_grad: if ctx.fc1_weight_requires_grad:
# Synchronize tensor-parallel communication for FC1 GEMM input tensor # Prepare input tensor
if ctx.ub_bulk_dgrad: # Note: Synchronize tensor-parallel communication and
ln_out_total = ub_obj_fc1_dgrad.get_buffer(ctx.fc1_input_quantizer) # make sure required data is available
if ctx.fp8:
if ln_out._data is None:
# All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD.
ln_out_total = _fix_gathered_fp8_transpose(ln_out_total, ctx.tp_size)
elif not non_tn_fp8_gemm_supported():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
ln_out_total._create_transpose()
if ln_out_total_work is not None: if ln_out_total_work is not None:
ln_out_total_work.wait() ln_out_total_work.wait()
ln_out_total_work = None ln_out_total_work = None
if ctx.fc1_input_quantizer is not None and not isinstance( if ctx.fp8 or ctx.debug:
ln_out_total, QuantizedTensor if isinstance(ln_out_total, QuantizedTensorBase):
): ln_out_total.update_usage(columnwise_usage=True)
# Async gather in BF16 does not asynchronously else:
# call quantizer after gather. ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True)
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
ln_out_total = ctx.fc1_input_quantizer(ln_out_total)
# Prepare grad output tensor
# Make sure GEMM inputs have required data # Note: Synchronize tensor-parallel communication and
if isinstance(ln_out_total, QuantizedTensor): # make sure required data is available
ln_out_total.update_usage(columnwise_usage=True) if ctx.fp8 or ctx.debug:
if isinstance(dact, QuantizedTensor): if isinstance(dact, QuantizedTensorBase):
dact.update_usage(columnwise_usage=True) dact.update_usage(columnwise_usage=True)
else:
ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
dact = ctx.fc1_grad_output_quantizer(dact)
# Output buffer for overlapping grad input # Output buffer for overlapping grad input
# reduce-scatter with wgrad GEMM # reduce-scatter with wgrad GEMM
reduce_scatter_out = None
if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf(): if ctx.ub_bulk_wgrad and ub_obj_fc1_wgrad.is_fp8_ubuf():
fc1_dgrad_rs_out = torch.empty( reduce_scatter_out = torch.empty(
fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda" fc1_dgrad_shape, dtype=ctx.activation_dtype, device="cuda"
) )
# wgrad GEMM # Arguments to include in wgrad GEMM closure
general_gemm_fc1_wgrad = functools.partial( fc1_wgrad_gemm_kwargs = {
general_gemm, "workspace": get_workspace(),
out_dtype=( "out_dtype": (
origin_fc1_weight.main_grad.dtype origin_fc1_weight.main_grad.dtype
if ctx.fuse_wgrad_accumulation if ctx.fuse_wgrad_accumulation
else ctx.activation_dtype else ctx.activation_dtype
), ),
workspace=get_workspace(), "quantization_params": ctx.fc1_grad_weight_quantizer,
layout="NT", "accumulate": accumulate_wgrad_into_param_main_grad,
quantization_params=ctx.fc1_grad_weight_quantizer, "layout": "NT",
grad=fuse_gemm_and_bias_fc1_wgrad, "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
accumulate=accumulate_wgrad_into_param_main_grad, "use_split_accumulator": wgrad_use_split_accumulator,
out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, "grad": fuse_gemm_and_bias_fc1_wgrad,
ub=ub_obj_fc1_wgrad, "ub": ub_obj_fc1_wgrad,
ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None, "ub_type": ub_type_fc1_wgrad,
extra_output=fc1_dgrad_rs_out, "extra_output": reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_wgrad, "bulk_overlap": ctx.ub_bulk_wgrad,
) }
def fc1_wgrad_gemm(
x: torch.Tensor,
dy: torch.Tensor,
_is_delayed: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Perform FC1 WGRAD GEMM
May be called outside of this function to enable
some advanced communication/compute overlapping.
"""
dw, db, *_ = general_gemm(x, dy, **fc1_wgrad_gemm_kwargs)
return dw, db
# Choose whether to call wgrad GEMM now or delay
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad) if (
fc1_wgrad_gemm_kwargs["ub"] is not None
or fc1_wgrad_gemm_kwargs["ub_type"] is not None
or fc1_wgrad_gemm_kwargs["extra_output"] is not None
or fc1_wgrad_gemm_kwargs["bulk_overlap"]
):
raise NotImplementedError(
"Delayed weight grad computation is not supported "
"with Userbuffers (tensor-parallel communication overlapping)"
)
ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
fc1_wgrad = None fc1_wgrad = None
if fuse_gemm_and_bias_fc1_wgrad: if fuse_gemm_and_bias_fc1_wgrad:
fc1_bias_grad = None fc1_bias_grad = None
else: else:
fc1_wgrad_outputs = general_gemm_fc1_wgrad(
ln_out_total,
dact,
)
clear_tensor_data(ln_out_total, dact)
# Call wgrad GEMM now
fc1_wgrad_outputs = fc1_wgrad_gemm(ln_out_total, dact)
if fuse_gemm_and_bias_fc1_wgrad: if fuse_gemm_and_bias_fc1_wgrad:
fc1_wgrad, fc1_bias_grad, *_ = fc1_wgrad_outputs fc1_wgrad, fc1_bias_grad = fc1_wgrad_outputs
else: else:
fc1_wgrad, *_ = fc1_wgrad_outputs fc1_wgrad, _ = fc1_wgrad_outputs
# Deallocate tensors if permitted
clear_tensor_data(dact)
if not ctx.return_layernorm_output_gathered:
clear_tensor_data(ln_out_total)
# Update grad input if overlapping reduce-scatter with wgrad GEMM
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
if ub_obj_fc1_wgrad.is_fp8_ubuf(): if ub_obj_fc1_wgrad.is_fp8_ubuf():
fc1_dgrad = fc1_dgrad_rs_out fc1_dgrad = reduce_scatter_out
else: else:
fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(None, local_chunk=True) fc1_dgrad = ub_obj_fc1_wgrad.get_buffer(local_chunk=True).clone()
# --------------------------------------------------
# Finished FC1 WGRAD...
# --------------------------------------------------
# Make sure all tensor-parallel communication is finished # Make sure all tensor-parallel communication is finished
if ln_out_total_work is not None: if ln_out_total_work is not None:
@@ -1748,7 +1863,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
) = [None] * 12
if self.fp8:
fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
fc1_input_quantizer.internal = False  # temporary
fc1_input_quantizer.internal = True
fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
fc1_weight_quantizer.internal = True
fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
@@ -1756,6 +1871,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
rowwise=True,
columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)),
)
fc1_input_quantizer.internal = True
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True
if fp8_output:
@@ -1764,11 +1880,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
]
if torch.is_grad_enabled():
fc2_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
]
fc2_grad_output_quantizer.internal = True
fc1_grad_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
]
fc1_grad_output_quantizer.internal = True
@@ -1853,25 +1969,25 @@ class LayerNormMLP(TransformerEngineBaseModule):
else:
# fc2_grad_output_quantizer: set configs about amax epsilon and power_2_scale for fc2_grad_output_quantizer
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
# fc1_grad_output_quantizer: also set numerical configs for fc1_grad_output_quantizer
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
].force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_INPUT1
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon
if self.sequence_parallel and self.set_parallel_mode:
# fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].with_amax_reduction = True
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
tex.FP8BwdTensors.GRAD_OUTPUT2
].amax_reduction_group = self.tp_group
def backward_dw(self):
......
...@@ -6,24 +6,26 @@ ...@@ -6,24 +6,26 @@
from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import warnings
import functools
import torch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch import torch_version
from .base import (
get_workspace,
fill_userbuffers_buffer_for_all_gather,
get_dummy_wgrad,
get_ub,
get_workspace,
TransformerEngineBaseModule,
get_dummy_wgrad,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ._common import noop_cat, _fix_gathered_fp8_transpose, WeightGradStore
from ._common import noop_cat, WeightGradStore
from ..fp8 import FP8GlobalStateManager
from ..utils import (
cast_if_needed,
@@ -32,7 +34,6 @@ from ..utils import (
init_method_constant,
requires_grad,
needs_quantized_gemm,
non_tn_fp8_gemm_supported,
assert_dim_for_fp8_exec,
nvtx_range_pop,
nvtx_range_push,
@@ -57,6 +58,7 @@ from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.quantized_tensor import (
QuantizedTensor,
QuantizedTensorBase,
Quantizer,
prepare_for_saving,
restore_from_saved,
@@ -125,88 +127,100 @@ class _Linear(torch.autograd.Function):
# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
assert inp_shape[-1] == in_features, "GEMM not possible"
assert inp.shape[-1] == in_features, "GEMM not possible"
# Configure tensor-parallel communication
tp_world_size = get_distributed_world_size(tp_group)
backward_needs_input = is_grad_enabled and weight.requires_grad
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp.view(-1, in_features)
inputmat_total = None
with_input_all_gather_nccl = ( with_input_all_gather_nccl = (
parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop
) )
own_quantized_input = False
# TODO(kwyss): Support FP8 allgather for FP8 block quantization. # Do TP communication in high precision if quantized format
# does not support communication
force_hp_input_gather = ( force_hp_input_gather = (
fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer) fp8 and with_input_all_gather_nccl and isinstance(input_quantizer, Float8BlockQuantizer)
) # Perform TP communication in high precision. )
# Configure Userbuffers communication (comm+GEMM overlap)
ub_obj = None
ub_type = None
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
# ------------------------------------------------------
# Prepare input tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.input_cast_comm")
inputmat = inp # Input tensor to save for backward (maybe sharded)
inputmat_total = None # Input tensor to pass to GEMM (gathered)
own_quantized_input = False
if fp8: if fp8:
assert_dim_for_fp8_exec(inputmat, weight) assert_dim_for_fp8_exec(inputmat, weight)
if any([ub_overlap_ag_fprop, ub_overlap_rs_fprop]) and not ( if with_input_all_gather_nccl or ub_overlap_ag_fprop: # All-gather input tensor
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
): # Cast local input tensor if needed
raise NotImplementedError( if fp8 or debug:
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor" if input_quantizer is None:
" current scaling" raise ValueError("Missing quantizer for input tensor")
) if not force_hp_input_gather and not isinstance(inputmat, QuantizedTensorBase):
if fp8 or debug: input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
if input_quantizer is None: if isinstance(
raise ValueError("Missing quantizer for input tensor") input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
if with_input_all_gather_nccl: ):
if force_hp_input_gather: # All-gather is not supported with FP8 column-wise data
input_quantizer.set_usage(rowwise=True, columnwise=False) input_quantizer.set_usage(columnwise=False)
inputmat_total, _ = gather_along_first_dim( inputmat = input_quantizer(inputmat)
inputmat, tp_group, quantizer=input_quantizer own_quantized_input = True
)
else:
if not isinstance(inputmat, QuantizedTensor):
columnwise_usage = backward_needs_input and isinstance(
input_quantizer, MXFP8Quantizer
)
# force_hp_input_gather should enforce this
assert not isinstance(input_quantizer, Float8BlockQuantizer)
input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
inputmat = input_quantizer(inputmat)
own_quantized_input = True
input_quantizer.set_usage(rowwise=True, columnwise=False)
inputmat_total, _ = gather_along_first_dim(
inputmat,
tp_group,
quantizer=input_quantizer,
)
else: else:
if ( inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
FP8GlobalStateManager.get_fp8_recipe().float8_per_tensor_scaling()
and ub_bulk_dgrad # Initialize gathered input tensor
): quantizer = None
# reduce duplicated transpose in `_fix_gathered_fp8_transpose` if fp8 or debug:
input_quantizer.set_usage(rowwise=True, columnwise=False) quantizer = input_quantizer
quantizer.set_usage(rowwise=True, columnwise=False)
if with_input_all_gather_nccl: # Perform NCCL all-gather
inputmat_total, _ = gather_along_first_dim(
inputmat,
tp_group,
quantizer=quantizer,
)
elif ub_overlap_ag_fprop: # Initialize Userbuffers all-gather
inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
ub_obj,
inputmat,
quantizer,
tp_group,
)
else: # Do not all-gather input tensor
if fp8 or debug:
if isinstance(inputmat, QuantizedTensorBase):
inputmat.update_usage(rowwise_usage=True)
else: else:
input_quantizer.set_usage( if input_quantizer is None:
rowwise=True, raise ValueError("Missing quantizer for input tensor")
columnwise=backward_needs_input, input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
)
if not isinstance(inputmat, QuantizedTensor):
inputmat = input_quantizer(inputmat) inputmat = input_quantizer(inputmat)
own_quantized_input = True own_quantized_input = True
elif backward_needs_input:
inputmat.update_usage(rowwise_usage=True, columnwise_usage=True)
inputmat_total = inputmat
else:
inputmat = cast_if_needed(inp, activation_dtype)
if with_input_all_gather_nccl:
inputmat_total, _ = gather_along_first_dim(inputmat, tp_group)
else: else:
inputmat_total = inputmat inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP
inputmat_total = inputmat
nvtx_range_pop(f"{nvtx_label}.input_cast_comm") nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
# ------------------------------------------------------
# Input tensor is ready for GEMM...
# ------------------------------------------------------
# Cast weight to expected dtype # ------------------------------------------------------
# Prepare weight tensor
# ------------------------------------------------------
weightmat = weight weightmat = weight
if fp8 or debug: if fp8 or debug:
# Configure quantizer # Configure quantizer
if weight_quantizer is not None: if weight_quantizer is not None:
@@ -217,7 +231,8 @@ class _Linear(torch.autograd.Function):
and not in_fp8_activation_recompute_phase() and not in_fp8_activation_recompute_phase()
) )
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
# FP8 cast to workspace buffer
# Get quantized weight
update_workspace = is_first_microbatch is None or is_first_microbatch update_workspace = is_first_microbatch is None or is_first_microbatch
weightmat = module.get_weight_workspace( weightmat = module.get_weight_workspace(
tensor=weight, tensor=weight,
@@ -228,19 +243,21 @@ class _Linear(torch.autograd.Function):
fsdp_group=fsdp_group, fsdp_group=fsdp_group,
workspace_dtype=activation_dtype, workspace_dtype=activation_dtype,
) )
weightmat.update_usage(rowwise_usage=True)
else: else:
weightmat = cast_if_needed(weightmat, activation_dtype) weightmat = cast_if_needed(weightmat, activation_dtype) # Cast for AMP
# ------------------------------------------------------
# Weight tensor is ready for GEMM...
# ------------------------------------------------------
# Cast bias to expected dtype # Cast bias to expected dtype
bias_dtype = activation_dtype bias_dtype = activation_dtype
if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32: if needs_quantized_gemm(inputmat_total) and activation_dtype == torch.float32:
# cuBLAS does not support FP8 GEMM with FP32 bias, so we cast to BF16
bias_dtype = torch.bfloat16 bias_dtype = torch.bfloat16
bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
# Calibrate quantizers if needed # Calibrate quantizers if needed
if not fp8 and fp8_calibration: if not fp8 and fp8_calibration:
if input_quantizer is not None: if input_quantizer is not None:
@@ -248,44 +265,74 @@ class _Linear(torch.autograd.Function):
if weight_quantizer is not None: if weight_quantizer is not None:
weight_quantizer.calibrate(weight) weight_quantizer.calibrate(weight)
ub_obj = None # Choose whether to use GEMM kernel with split accumulator
ub_type = None use_split_accumulator = _2X_ACC_FPROP
rs_out = None
out_dtype = activation_dtype
if ub_overlap_rs_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.RS
out_shape = [reduce(multiply_op, inp_shape[:-1]) // tp_world_size, out_features]
rs_out = torch.empty(out_shape, dtype=activation_dtype, device=inputmat_total.device)
elif ub_overlap_ag_fprop:
ub_obj = get_ub(ub_name + "_fprop")
ub_type = tex.CommOverlapType.AG
if fp8:
assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True)
inputmat_total = ub_obj.get_buffer(input_quantizer)
nvtx_range_push(f"{nvtx_label}.gemm")
fprop_gemm_use_split_accumulator = _2X_ACC_FPROP
if fp8: if fp8:
recipe = FP8GlobalStateManager.get_fp8_recipe() recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"): if hasattr(recipe, "fp8_gemm_fprop"):
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
# Configure output quantizer
if output_quantizer is not None:
output_quantizer.set_usage(rowwise=True, columnwise=False)
out, *_, rs_out = general_gemm( # Output buffer for Userbuffers reduce-scatter
reduce_scatter_out = None
if ub_overlap_rs_fprop:
out_shape = list(inp.shape)
out_shape[0] //= tp_world_size
out_shape[-1] = out_features
reduce_scatter_out = torch.empty(out_shape, dtype=activation_dtype, device=inp.device)
# ------------------------------------------------------
# Forward GEMM
# Note: y = x * w^T
# ------------------------------------------------------
nvtx_range_push(f"{nvtx_label}.gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weightmat, weightmat,
inputmat_total, inputmat_total,
get_workspace(), get_workspace(),
quantization_params=output_quantizer, quantization_params=output_quantizer,
out_dtype=out_dtype, out_dtype=activation_dtype,
bias=bias, bias=bias,
use_split_accumulator=fprop_gemm_use_split_accumulator, use_split_accumulator=use_split_accumulator,
ub=ub_obj, ub=ub_obj,
ub_type=ub_type, ub_type=ub_type,
extra_output=rs_out, extra_output=reduce_scatter_out,
) )
nvtx_range_pop(f"{nvtx_label}.gemm") nvtx_range_pop(f"{nvtx_label}.gemm")
# ------------------------------------------------------
# Finished forward GEMM...
# ------------------------------------------------------
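Editorial note: as the comment above states, the forward GEMM computes y = x * w^T. A quick shape check with made-up sizes (illustrative only, not taken from this diff):
import torch
x = torch.randn(16, 1024)     # inputmat_total: [tokens, in_features]
w = torch.randn(4096, 1024)   # weight: [out_features, in_features]
y = x @ w.T
assert y.shape == (16, 4096)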
# ------------------------------------------------------
# Prepare output tensor
# Note: Perform tensor-parallel communication
# ------------------------------------------------------
out = None
if ub_overlap_rs_fprop:
out = reduce_scatter_out
elif parallel_mode == "row" and tp_size > 1:
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
out = gemm_out
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
else:
out = gemm_out
# ------------------------------------------------------
# Output tensor is ready to return...
# ------------------------------------------------------
# ------------------------------------------------------
# Cache state for backward pass
# ------------------------------------------------------
if is_grad_enabled: if is_grad_enabled:
ctx.weight_quantizer = weight_quantizer ctx.weight_quantizer = weight_quantizer
@@ -296,19 +343,19 @@ class _Linear(torch.autograd.Function):
) )
if backward_needs_input: if backward_needs_input:
if own_quantized_input and isinstance(inputmat, QuantizedTensor): if own_quantized_input and isinstance(inputmat, QuantizedTensorBase):
# For sequence parallel in vanilla FP8, rowwise data is # For sequence parallel in vanilla FP8, rowwise data is
# to gather the input. For MXFP8, columnwise only data # to gather the input. For MXFP8, columnwise only data
# can be allgathered. # can be allgathered.
if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather: if isinstance(inputmat, MXFP8TensorBase) or not ctx.backward_input_needs_gather:
inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) inputmat.update_usage(rowwise_usage=False, columnwise_usage=True)
if force_hp_input_gather: if force_hp_input_gather:
assert not isinstance(inputmat, QuantizedTensor) assert not isinstance(inputmat, QuantizedTensorBase)
saved_inputmat = inputmat saved_inputmat = inputmat
# Weight with column-wise usage is needed for dgrad GEMM. # Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad: if inp.requires_grad:
if isinstance(weightmat, QuantizedTensor): if isinstance(weightmat, QuantizedTensorBase):
weightmat.update_usage(columnwise_usage=True) weightmat.update_usage(columnwise_usage=True)
if cpu_offloading and saved_inputmat is not None: if cpu_offloading and saved_inputmat is not None:
@@ -321,7 +368,7 @@ class _Linear(torch.autograd.Function):
ctx.fsdp_shapes = _fsdp_scatter_tensors( ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group, fsdp_group,
saved_inputmat, saved_inputmat,
weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None, weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None,
) )
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
@@ -364,7 +411,7 @@ class _Linear(torch.autograd.Function):
ctx.use_bias = bias is not None ctx.use_bias = bias is not None
ctx.sequence_parallel = sequence_parallel ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp_shape ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group ctx.tp_group = tp_group
ctx.ub_overlap_ag = ub_overlap_ag_dgrad ctx.ub_overlap_ag = ub_overlap_ag_dgrad
@@ -376,6 +423,7 @@ class _Linear(torch.autograd.Function):
ctx.requires_dgrad = inp.requires_grad ctx.requires_dgrad = inp.requires_grad
ctx.requires_wgrad = weight.requires_grad ctx.requires_wgrad = weight.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False ctx.reduce_and_update_bwd_fp8_tensors = False
ctx.owns_input = saved_inputmat is not inp ctx.owns_input = saved_inputmat is not inp
if ctx.fp8 and requires_grad(inp, weight, bias): if ctx.fp8 and requires_grad(inp, weight, bias):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
@@ -384,21 +432,10 @@ class _Linear(torch.autograd.Function):
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
ctx.wgrad_store = wgrad_store ctx.wgrad_store = wgrad_store
# Row Parallel Linear # ------------------------------------------------------
if ub_overlap_rs_fprop: # Cached state for backward pass is ready...
out = rs_out # ------------------------------------------------------
elif parallel_mode == "row":
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
if symmetric_ar_type is not None:
out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type)
else:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")
out = out.view(-1, *inp_shape[1:-1], out_features)
return out return out
@staticmethod @staticmethod
@@ -411,28 +448,11 @@ class _Linear(torch.autograd.Function):
nvtx_label = f"{nvtx_label}.{ctx.ub_name}" nvtx_label = f"{nvtx_label}.{ctx.ub_name}"
with torch.cuda.nvtx.range("_Linear_backward"): with torch.cuda.nvtx.range("_Linear_backward"):
if (
ctx.fp8
and any(
[
ctx.ub_overlap_ag,
ctx.ub_overlap_rs_dgrad,
ctx.ub_bulk_dgrad,
ctx.ub_bulk_wgrad,
]
)
and (ctx.fp8_recipe is not None)
):
if not ctx.fp8_recipe.float8_per_tensor_scaling():
raise NotImplementedError(
"Comm+GEMM overlap is only supported with FP8 delayed scaling or per-tensor"
" current scaling"
)
saved_tensors = ctx.saved_tensors saved_tensors = ctx.saved_tensors
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors) restore_from_saved(ctx.tensor_objects, saved_tensors)
) )
# Delete the references to tensor objects once they've been consumed # Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors. # by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None ctx.tensor_objects = None
@@ -462,69 +482,55 @@ class _Linear(torch.autograd.Function):
) )
nvtx_range_pop(f"{nvtx_label}.fsdp_gather") nvtx_range_pop(f"{nvtx_label}.fsdp_gather")
# Configure Userbuffers communication (comm+GEMM overlap)
ctx.ub_obj_gradout = None ctx.ub_obj_gradout = None
ub_obj_dgrad = None ub_obj_dgrad = None
ub_obj_wgrad = None ub_obj_wgrad = None
ub_type_dgrad = None ub_type_dgrad = None
ub_type_wgrad = None ub_type_wgrad = None
dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]]
rs_out = None
dgrad_bulk = None
if ctx.ub_overlap_ag: if ctx.ub_overlap_ag:
# Overlap grad_output all-gather with dgrad compute # Overlap grad_output all-gather with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG ub_type_dgrad = tex.CommOverlapType.AG
elif ctx.ub_overlap_rs_dgrad: elif ctx.ub_overlap_rs_dgrad:
# Overlap dgrad reduce-scatter with dgrad compute # Overlap dgrad reduce-scatter with dgrad compute
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.RS ub_type_dgrad = tex.CommOverlapType.RS
rs_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output.device
)
else: else:
if ctx.ub_bulk_dgrad: if ctx.ub_bulk_dgrad:
# Overlap inputmat all-gather with dgrad compute # Overlap inputmat all-gather with dgrad compute
# NOTE: Copying into communication buffer will always prefer rowwise data,
# and will copy columnwise data if rowwise does not exist. In that case,
# the all-gather will apply to the leading dimension of the transpose,
# which then needs to be interleaved correctly before WGRAD.
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad")
ub_obj_dgrad = ctx.ub_obj_gradout ub_obj_dgrad = ctx.ub_obj_gradout
ub_type_dgrad = tex.CommOverlapType.AG ub_type_dgrad = tex.CommOverlapType.AG
ub_obj_dgrad.copy_into_buffer(inputmat, ctx.input_quantizer, local_chunk=True)
if ctx.ub_bulk_wgrad: if ctx.ub_bulk_wgrad:
# Overlap dgrad reduce-scatter with wgrad compute # Overlap dgrad reduce-scatter with wgrad compute
ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad")
ub_type_wgrad = tex.CommOverlapType.RS ub_type_wgrad = tex.CommOverlapType.RS
ub_obj_wgrad.set_buffer_params(ctx.grad_input_quantizer)
dgrad_bulk = ub_obj_wgrad.get_buffer(ctx.grad_input_quantizer) # --------------------------------------------------
# Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
# --------------------------------------------------
# Unmodified grad output tensor
grad_output_arg = grad_output
# Configure quantizer for grad output tensor # Configure quantizer for grad output tensor
# Note: dgrad GEMM requires row-wise usage, wgrad GEMM # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
# requires column-wise usage # requires column-wise usage
if ctx.grad_output_quantizer is not None: if ctx.grad_output_quantizer is not None:
rowwise_usage = True quantizer = ctx.grad_output_quantizer
columnwise_usage = True quantizer.set_usage(rowwise=True, columnwise=True)
if ctx.ub_overlap_ag and isinstance( if ctx.ub_overlap_ag:
ctx.grad_output_quantizer, # Userbuffers only supports communication for one
(Float8Quantizer, Float8CurrentScalingQuantizer), # tensor usage at a time. Configure quantizer with
): # usage for only dgrad GEMM.
# If data is in FP8 and communication is handled quantizer.set_usage(columnwise=False)
# with Userbuffers, we compute FP8 transposes
# manually
columnwise_usage = False
ctx.grad_output_quantizer.set_usage(
rowwise=rowwise_usage,
columnwise=columnwise_usage,
)
# Prepare grad output tensor # Prepare grad output tensor
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") nvtx_range_push(f"{nvtx_label}.grad_output_preprocess")
( (
grad_output, grad_output,
@@ -537,12 +543,21 @@ class _Linear(torch.autograd.Function):
) )
nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess")
# Launch tensor-parallel communication for input tensor # --------------------------------------------------
# Grad output tensor is ready for computing grad input...
# --------------------------------------------------
# --------------------------------------------------
# Prepare input tensor
# Note: Input tensor is needed for wgrad GEMM.
# Tensor-parallel communication is overlapped with dgrad
# GEMM.
# --------------------------------------------------
inputmat_total = None inputmat_total = None
inputmat_total_work = None inputmat_total_work = None
if ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: if ctx.backward_input_needs_gather:
quantizer = None quantizer = None
if ctx.fp8 or ctx.debug: if (ctx.fp8 or ctx.debug) and not ctx.force_hp_input_gather:
quantizer = ctx.input_quantizer quantizer = ctx.input_quantizer
if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
# If data is in FP8, we compute FP8 transposes manually # If data is in FP8, we compute FP8 transposes manually
...@@ -550,72 +565,92 @@ class _Linear(torch.autograd.Function): ...@@ -550,72 +565,92 @@ class _Linear(torch.autograd.Function):
else: else:
# wgrad GEMM requires input with column-wise usage # wgrad GEMM requires input with column-wise usage
quantizer.set_usage(rowwise=False, columnwise=True) quantizer.set_usage(rowwise=False, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") if ctx.ub_bulk_dgrad:
gather_quantizer = None if ctx.force_hp_input_gather else quantizer inputmat_total, _ = fill_userbuffers_buffer_for_all_gather(
inputmat_total, inputmat_total_work = gather_along_first_dim( ub_obj_dgrad,
inputmat, inputmat,
ctx.tp_group, quantizer,
async_op=True, ctx.tp_group,
quantizer=gather_quantizer, )
) else:
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
inputmat_total, inputmat_total_work = gather_along_first_dim(
inputmat,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else: else:
inputmat_total = inputmat inputmat_total = inputmat
# --------------------------------------------------
# Input tensor is ready for computing grad weight...
# --------------------------------------------------
# Check whether to output wgrad GEMM directly into main grad # --------------------------------------------------
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
# Compute grad input tensor # Compute grad input tensor
# --------------------------------------------------
dgrad = None dgrad = None
dgrad_work = None dgrad_work = None
if ctx.requires_dgrad: if ctx.requires_dgrad:
# Update quantizer # Make sure required data is available
if ctx.grad_input_quantizer is not None: if isinstance(grad_output, QuantizedTensorBase):
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) grad_output.update_usage(rowwise_usage=True)
# dgrad GEMM if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase):
nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_fp8.update_usage(columnwise_usage=True)
dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD
# Choose whether to use GEMM kernel with split accumulator
use_split_accumulator = _2X_ACC_DGRAD
if ctx.fp8: if ctx.fp8:
recipe = ctx.fp8_recipe recipe = ctx.fp8_recipe
if hasattr(recipe, "fp8_gemm_dgrad"): if hasattr(recipe, "fp8_gemm_dgrad"):
dgrad_gemm_use_split_accumulator = ( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
recipe.fp8_gemm_dgrad.use_split_accumulator
)
if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor): # Update grad input quantizer
weight_fp8.update_usage( if ctx.grad_input_quantizer is not None:
rowwise_usage=ctx.weight_quantizer.rowwise_usage, ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)
columnwise_usage=ctx.weight_quantizer.columnwise_usage,
# Output buffers for Userbuffers reduce-scatter
gemm_out = None
reduce_scatter_out = None
if ctx.ub_overlap_rs_dgrad:
reduce_scatter_out = torch.empty(
dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
) )
elif ctx.ub_bulk_wgrad:
gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False)
dgrad, *_, rs_out = general_gemm( # dgrad GEMM
# Note: dx = dy * w
nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
gemm_out, *_, reduce_scatter_out = general_gemm(
weight_fp8, weight_fp8,
grad_output, grad_output,
get_workspace(), get_workspace(),
layout="NN", layout="NN",
grad=True, grad=True,
quantization_params=ctx.grad_input_quantizer, quantization_params=ctx.grad_input_quantizer,
out=dgrad_bulk, out=gemm_out,
out_dtype=ctx.activation_dtype, out_dtype=ctx.activation_dtype,
use_split_accumulator=dgrad_gemm_use_split_accumulator, use_split_accumulator=use_split_accumulator,
ub=ub_obj_dgrad, ub=ub_obj_dgrad,
ub_type=ub_type_dgrad, ub_type=ub_type_dgrad,
extra_output=rs_out, extra_output=reduce_scatter_out,
bulk_overlap=ctx.ub_bulk_dgrad, bulk_overlap=ctx.ub_bulk_dgrad,
) )
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")
# Launch tensor-parallel communication # Prepare grad input tensor
# Note: Perform tensor-parallel communication
if ctx.ub_overlap_rs_dgrad: if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out dgrad = reduce_scatter_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: elif ctx.ub_bulk_wgrad:
dgrad = ub_obj_wgrad.get_buffer(local_chunk=True)
elif ctx.parallel_mode == "column" and ctx.tp_size > 1:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
dgrad = gemm_out
if ctx.sequence_parallel: if ctx.sequence_parallel:
dgrad, dgrad_work = reduce_scatter_along_first_dim( dgrad, dgrad_work = reduce_scatter_along_first_dim(
dgrad, dgrad,
@@ -625,41 +660,55 @@ class _Linear(torch.autograd.Function):
else: else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")
else:
dgrad = gemm_out
# --------------------------------------------------
# Grad input tensor has been computed...
# --------------------------------------------------
# --------------------------------------------------
# Compute grad weight
# --------------------------------------------------
# Compute grad weight tensor
wgrad = None wgrad = None
if ctx.requires_wgrad: if ctx.requires_wgrad:
# Synchronize tensor-parallel communication for input tensor # Prepare input tensor
if ctx.ub_bulk_dgrad: # Note: Synchronize tensor-parallel communication and
inputmat_total = ub_obj_dgrad.get_buffer(ctx.input_quantizer) # make sure required data is available
if ctx.fp8:
if inputmat._data is None:
# All-gather executed on columnwise data and result is in rowwise data,
# so we need to fix the interleaving before WGRAD.
inputmat_total = _fix_gathered_fp8_transpose(
inputmat_total, ctx.tp_size
)
elif not non_tn_fp8_gemm_supported():
# FP8 GEMM on Hopper only supports TN layout so the gathered input must
# have a valid transpose.
inputmat_total._create_transpose()
if inputmat_total_work is not None: if inputmat_total_work is not None:
inputmat_total_work.wait() inputmat_total_work.wait()
inputmat_total_work = None inputmat_total_work = None
if ctx.input_quantizer is not None and not isinstance( if ctx.fp8 or ctx.debug:
inputmat_total, QuantizedTensor if isinstance(inputmat_total, QuantizedTensorBase):
): inputmat_total.update_usage(columnwise_usage=True)
# Async gather in BF16 does not asynchronously else:
# call quantizer after gather. ctx.input_quantizer.set_usage(rowwise=False, columnwise=True)
ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) inputmat_total = ctx.input_quantizer(inputmat_total)
inputmat_total = ctx.input_quantizer(inputmat_total)
# Prepare grad output tensor
# Make sure GEMM inputs have required data # Note: Synchronize tensor-parallel communication and
if isinstance(inputmat_total, QuantizedTensor): # make sure required data is available
inputmat_total.update_usage(columnwise_usage=True) if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer):
if isinstance(grad_output, QuantizedTensor): # UB does not support overlapping grad output
grad_output.update_usage(columnwise_usage=True) # all-gather with wgrad GEMM. Also, we can't
# convert row-scaled MXFP8 to column-scaled, so we
# can't reuse the grad output that was gathered
# for the dgrad GEMM. We work around with blocking
# all-gather for column-scaled MXFP8 data.
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output, _ = gather_along_first_dim(
grad_output_arg,
ctx.tp_group,
quantizer=ctx.grad_output_quantizer,
)
if ctx.fp8 or ctx.debug:
if isinstance(grad_output, QuantizedTensorBase):
grad_output.update_usage(columnwise_usage=True)
else:
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
grad_output = ctx.grad_output_quantizer(grad_output)
# Figure out whether to use split accumulator # Figure out whether to use split accumulator
use_split_accumulator = _2X_ACC_WGRAD use_split_accumulator = _2X_ACC_WGRAD
@@ -668,54 +717,95 @@ class _Linear(torch.autograd.Function):
                if hasattr(recipe, "fp8_gemm_wgrad"):
                    use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator

            # Figure out whether to output wgrad GEMM directly into main grad
            if ctx.is_first_microbatch is not None:
                accumulate_wgrad_into_param_main_grad = (
                    ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
                )
            else:
                accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
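            # Note: with fused gradient accumulation the first microbatch
            # writes the wgrad GEMM result straight into the param's main_grad
            # buffer (accumulate=False) and later microbatches accumulate into
            # it; without microbatch bookkeeping the behavior follows
            # ctx.fuse_wgrad_accumulation directly.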
            # Output buffer for overlapping FP8 grad input
            # reduce-scatter with wgrad GEMM
            reduce_scatter_out = None
            if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf():
                reduce_scatter_out = torch.empty(
                    dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device
                )

            # Arguments to include in wgrad GEMM closure
            wgrad_gemm_kwargs = {
                "workspace": get_workspace(),
                "out_dtype": (
                    main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype
                ),
                "quantization_params": ctx.grad_weight_quantizer,
                "accumulate": accumulate_wgrad_into_param_main_grad,
                "layout": "NT",
                "out": main_grad if ctx.fuse_wgrad_accumulation else None,
                "bias": (bias if (grad_bias is None and not ctx.fp8) else None),
                "use_split_accumulator": use_split_accumulator,
                "grad": True,
                "ub": ub_obj_wgrad,
                "ub_type": ub_type_wgrad,
                "extra_output": reduce_scatter_out,
                "bulk_overlap": ctx.ub_bulk_wgrad,
            }

            def wgrad_gemm(
                x: torch.Tensor,
                dy: torch.Tensor,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                """Perform wgrad GEMM: dw = dy^T * x

                May be fused with bgrad computation.

                May be called outside of this function to enable
                some advanced communication/compute overlapping.

                """
                nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
                dw, db, *_ = general_gemm(x, dy, **wgrad_gemm_kwargs)
                nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")
                return dw, db
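            # The closure captures all GEMM arguments so it can either run
            # immediately below or be handed to ctx.wgrad_store and executed
            # later, outside this backward call, to overlap the weight-grad
            # GEMM with other work (a sketch of the intended usage, based on
            # the delayed-compute branch below).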
            # Choose whether to call wgrad GEMM now or delay
            if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                if (
                    wgrad_gemm_kwargs["ub"] is not None
                    or wgrad_gemm_kwargs["ub_type"] is not None
                    or wgrad_gemm_kwargs["extra_output"] is not None
                    or wgrad_gemm_kwargs["bulk_overlap"]
                ):
                    raise NotImplementedError(
                        "Delayed weight grad computation is not supported "
                        "with Userbuffers (tensor-parallel communication overlapping)"
                    )
                ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
            else:

                # Call wgrad GEMM now
                wgrad, grad_bias_ = wgrad_gemm(inputmat_total, grad_output)

                # Update grad bias if needed
                if grad_bias is None:
                    grad_bias = grad_bias_
                del grad_bias_

                # Deallocate input tensor if permitted
                if ctx.owns_input:
                    clear_tensor_data(inputmat_total)

            # Update grad input if overlapping reduce-scatter with wgrad GEMM
            if ctx.ub_bulk_wgrad:
                if ub_obj_wgrad.is_fp8_ubuf():
                    dgrad = reduce_scatter_out
                else:
                    dgrad = ub_obj_wgrad.get_buffer(local_chunk=True).clone()

        # --------------------------------------------------
        # Grad weight has been computed...
        # --------------------------------------------------

        # Don't return grad bias if not needed
        if not ctx.use_bias:
@@ -753,13 +843,14 @@ class _Linear(torch.autograd.Function):
        else:
            wgrad = None

        # Update FP8 scaling factors if needed
        if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
            nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
            FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
            nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")

        # Scatter fp8 weight buffers
        if ctx.fp8 and not isinstance(weight, QuantizedTensorBase):
            _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8)

        return (
            wgrad,
@@ -1207,7 +1298,12 @@ class Linear(TransformerEngineBaseModule):
                    "Splitting QuantizedTensor into multiple params is not supported"
                )
            else:
                warnings.warn(
                    "You are using quantized weights without quantized compute. "
                    "Please make sure this is intentional."
                )
                unfused_weights = [w.dequantize() for w in unfused_weights]
        weight_tensor = noop_cat(unfused_weights)
        if self.use_bias:
            bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names])
@@ -1302,7 +1398,7 @@ class Linear(TransformerEngineBaseModule):
            grad_output_quantizer = None
            output_quantizer = None
            input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
            input_quantizer.internal = True
            weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
            weight_quantizer.internal = True
            if fp8_output:
...
@@ -98,6 +98,9 @@ class RMSNorm(_RMSNormOp):
            )
        kwargs["dtype"] = params_dtype

        # Flag for sequence parallelism (custom Megatron-LM integration)
        self.sequence_parallel: Optional[bool] = sequence_parallel
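        # Setting the flag before calling the parent constructor means any
        # parameter initialization done there (e.g. reset_parameters below)
        # can already see it and tag the weight accordingly.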
        # Initialize RMSNorm operation
        super().__init__(
            normalized_shape,
@@ -106,11 +109,6 @@ class RMSNorm(_RMSNormOp):
            **kwargs,
        )

    def reset_rms_norm_parameters(self) -> None:
        """Deprecated"""
        warnings.warn(
@@ -139,7 +137,7 @@ class RMSNorm(_RMSNormOp):
        super().reset_parameters()

        # Flag for sequence parallelism (custom Megatron-LM integration)
        if self.sequence_parallel is not None:
            self.weight.sequence_parallel = self.sequence_parallel

    @property
...
@@ -534,7 +534,9 @@ class BasicLinear(BasicOperation):
        # Configure input tensor for backward pass
        if with_quantized_compute and isinstance(x_local, QuantizedTensor):
            if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather):
                # FP8 does not support all-gather of transpose data
                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

        return y, x_local, w
@@ -622,7 +624,10 @@ class BasicLinear(BasicOperation):
        # Check datatype
        if dtype is None:
            if weight is not None:
                dtype = weight.dtype
            else:
                dtype = grad_output.dtype
        dtype = canonicalize_dtype(dtype)
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")
@@ -814,7 +819,7 @@ class BasicLinear(BasicOperation):
        x_async = None
        dy_async = None

        # Check grad weight tensor
        dw = grad_weight
        dw_dtype = dtype
        if dw is None:
...
@@ -4,30 +4,27 @@
"""Linear layer backward with Userbuffers communication."""

### TODO Debug Userbuffers support

from __future__ import annotations

from typing import Optional
import warnings

import torch

from transformer_engine_torch import CommOverlapType

from ...cpp_extensions import general_gemm
from ...distributed import gather_along_first_dim, get_distributed_world_size
from ...module.base import (
    fill_userbuffers_buffer_for_all_gather,
    get_ub,
    get_workspace,
    _2X_ACC_DGRAD,
    _2X_ACC_WGRAD,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import FusedOperation, FusibleOperation, OperationContext


class UserbuffersBackwardLinear(FusedOperation):
@@ -47,9 +44,6 @@ class UserbuffersBackwardLinear(FusedOperation):
        reduce_scatter: Optional[ReduceScatter],
    ) -> None:

        # Basic operations that comprise this fused operation
        op_idxs = {"linear": None, "bias": None, "reduce_scatter": None}
        ops = []
@@ -89,9 +83,8 @@ class UserbuffersBackwardLinear(FusedOperation):
        grad_output: torch.Tensor,
        input: Optional[torch.Tensor],  # pylint: disable=redefined-builtin
        weight: Optional[torch.Tensor],
        *,
        input_requires_grad: bool = True,
        weight_requires_grad: bool = True,
        bias_requires_grad: bool = False,
        device: Optional[torch.device] = None,
@@ -102,11 +95,11 @@ class UserbuffersBackwardLinear(FusedOperation):
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        tensor_parallel_size: Optional[int] = None,
        sequence_parallel: bool = False,
        with_quantized_compute: bool = False,
        input_quantizer: Optional[Quantizer] = None,
        weight_quantizer: Optional[Quantizer] = None,
        grad_output_quantizer: Optional[Quantizer] = None,
        grad_input_quantizer: Optional[Quantizer] = None,
        ub_comm_name: str,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], dict]:
        """Functional API for backward pass
@@ -121,10 +114,6 @@ class UserbuffersBackwardLinear(FusedOperation):
        weight: torch.Tensor, optional
            Weight tensor. Required to compute loss gradient w.r.t.
            input.
        weight_requires_grad: bool
            Whether to compute loss gradient w.r.t. weight tensor
        bias_requires_grad: bool
@@ -146,21 +135,18 @@ class UserbuffersBackwardLinear(FusedOperation):
            parallelism, i.e. distributing input or output tensors
            along outer dimension (sequence or batch dim) when not
            distributing along inner dimension (embedding dim)
        with_quantized_compute: bool, default = `False`
            Whether to perform compute with quantized data.
        input_quantizer: Quantizer, optional
            Builder class for quantized input tensor.
        weight_quantizer: Quantizer, optional
            Builder class for quantized weight tensor.
        grad_output_quantizer: Quantizer, optional
            Builder class for quantized loss gradient w.r.t. output
            tensor.
        grad_input_quantizer: Quantizer, optional
            Builder class for quantized loss gradient w.r.t. input
            tensor.
        ub_comm_name: str
            Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
            used to access the corresponding Userbuffers communicators
@@ -183,37 +169,24 @@ class UserbuffersBackwardLinear(FusedOperation):
        # Check device
        if device is None:
            if weight is not None:
                device = weight.device
            else:
                device = grad_output.device
        device = canonicalize_device(device)
        if device.type != "cuda":
            raise ValueError(f"Only CUDA devices are supported (got {device})")

        # Check datatype
        if dtype is None:
            if weight is not None:
                dtype = weight.dtype
            else:
                dtype = grad_output.dtype
        dtype = canonicalize_dtype(dtype)
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

        # Check tensor parallel group
        if tensor_parallel_size is None:
            tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
@@ -227,373 +200,283 @@ class UserbuffersBackwardLinear(FusedOperation):
        if not sequence_parallel:
            raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")

        # dgrad GEMM is required
        if not input_requires_grad:
            warnings.warn(
                "Linear input doesn't require gradient, "
                "but Userbuffers implementation requires dgrad GEMM."
            )
            input_requires_grad = True

        # Check quantizers
        if with_quantized_compute:
            if weight_requires_grad and input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            if input_requires_grad and weight_quantizer is None:
                raise ValueError("Missing quantizer for weight tensor")
            if grad_output_quantizer is None:
                raise ValueError("Missing quantizer for grad output tensor")
            if grad_input_quantizer is not None:
                raise ValueError("Quantized grad input is not supported")
        else:
            input_quantizer = None
            weight_quantizer = None
            grad_output_quantizer = None
            grad_input_quantizer = None

        # Get Userbuffers communicators
        # Note: Communication patterns are (1) overlap dy all-gather
        # with dgrad GEMM, (2) overlap x all-gather with dgrad GEMM
        # and dx reduce-scatter with wgrad GEMM, (3) overlap dx
        # reduce-scatter with dgrad GEMM
        ub_comm_dgrad = None
        ub_comm_wgrad = None
        ub_type_dgrad = None
        ub_type_wgrad = None
        with_bulk_overlap = False
        with_dgrad_all_gather_dy = False
        with_dgrad_reduce_scatter_dx = False
        with_dgrad_all_gather_x = False
        with_wgrad_reduce_scatter_dx = False
        if tensor_parallel_mode == "row":
            ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
            ub_type_dgrad = CommOverlapType.AG
            with_dgrad_all_gather_dy = True
        elif tensor_parallel_mode == "column":
            if input_requires_grad and weight_requires_grad:
                with_bulk_overlap = True
                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
                ub_type_dgrad = CommOverlapType.AG
                with_dgrad_all_gather_x = True
                ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad")
                ub_type_wgrad = CommOverlapType.RS
                with_wgrad_reduce_scatter_dx = True
                if ub_comm_wgrad.is_fp8_ubuf():
                    raise RuntimeError(
                        "Userbuffers reduce-scatter is not supported with FP8 buffers"
                    )
            else:
                ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad")
                ub_type_dgrad = CommOverlapType.RS
                with_dgrad_reduce_scatter_dx = True
                if ub_comm_dgrad.is_fp8_ubuf():
                    raise RuntimeError(
                        "Userbuffers reduce-scatter is not supported with FP8 buffers"
                    )
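        # Resulting overlap patterns (for reference):
        #   row-parallel:                dgrad GEMM overlaps dy all-gather (AG)
        #   column-parallel, with wgrad: dgrad GEMM bulk-overlaps x all-gather (AG)
        #                                and wgrad GEMM bulk-overlaps dx reduce-scatter (RS)
        #   column-parallel, no wgrad:   dgrad GEMM overlaps dx reduce-scatter (RS)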
        # Compute grad bias if needed
        db = None
        db_async = None
        if bias_requires_grad:
            db = grad_output.sum(tuple(range(grad_output.dim() - 1)))
            if tensor_parallel_mode == "row":
                db_async = torch.distributed.all_reduce(
                    db,
                    group=tensor_parallel_group,
                    async_op=True,
                )

        # Cast grad output tensor dtype if needed
        dy_local = grad_output
        if with_quantized_compute:
            if not isinstance(dy_local, QuantizedTensorBase):
                with_columnwise = weight_requires_grad
                if (
                    with_columnwise
                    and with_dgrad_all_gather_dy
                    and not isinstance(grad_output_quantizer, MXFP8Quantizer)
                ):
                    with_columnwise = False
                grad_output_quantizer.set_usage(
                    rowwise=True,
                    columnwise=with_columnwise,
                )
                dy_local = grad_output_quantizer(dy_local)
        else:
            if isinstance(dy_local, QuantizedTensorBase):
                dy_local = dy_local.dequantize(dtype=dtype)
            elif dy_local.dtype != dtype:
                dy_local = dy_local.to(dtype=dtype)

        # Cast weight tensor dtype if needed
        if weight is None:
            raise ValueError("Weight tensor is required to compute input grad")
        w = weight
        if with_quantized_compute:
            if not isinstance(w, QuantizedTensorBase):
                weight_quantizer.set_usage(columnwise=True)
                w = weight_quantizer(w)
        else:
            if isinstance(w, QuantizedTensorBase):
                w = w.dequantize(dtype=dtype)
            elif w.dtype != dtype:
                w = w.to(dtype=dtype)

        # Cast input tensor dtype if needed
        x_local = None
        if weight_requires_grad:
            if input is None:
                raise ValueError("Input tensor is required to compute weight grad")
            x_local = input
            if with_quantized_compute:
                if not isinstance(x_local, QuantizedTensorBase):
                    input_quantizer.set_usage(columnwise=True)
                    x_local = input_quantizer(x_local)
            else:
                if isinstance(x_local, QuantizedTensorBase):
                    x_local = x_local.dequantize(dtype=dtype)
                elif x_local.dtype != dtype:
                    x_local = x_local.to(dtype=dtype)
        # dgrad GEMM
        dx_local = None
        dx = None
        dy = None
        x = None
        if input_requires_grad:

            # Initialize grad output
            if with_dgrad_all_gather_dy:
                if grad_output_quantizer is not None:
                    grad_output_quantizer.set_usage(rowwise=True, columnwise=False)
                dy, _ = fill_userbuffers_buffer_for_all_gather(
                    ub_comm_dgrad,
                    dy_local,
                    grad_output_quantizer,
                    tensor_parallel_group,
                )
            else:
                dy = dy_local

            # Construct grad input tensor if needed
            if with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx:
                dx_size = list(dy.size())
                dx_size[-1] = w.size(-1)
                dx_local_size = list(dx_size)
                dx_local_size[0] //= tensor_parallel_size
                if with_dgrad_reduce_scatter_dx:
                    dx_local = torch.empty(
                        dx_local_size,
                        dtype=dtype,
                        device=device,
                    )
                elif with_wgrad_reduce_scatter_dx:
                    dx_local = ub_comm_wgrad.get_buffer(
                        local_chunk=True,
                        shape=dx_local_size,
                    )
                    dx = ub_comm_wgrad.get_buffer(
                        local_chunk=False,
                        shape=dx_size,
                    )

            # Initialize input tensor if needed
            if with_dgrad_all_gather_x:
                if input_quantizer is not None:
                    input_quantizer.set_usage(rowwise=False, columnwise=True)
                x, _ = fill_userbuffers_buffer_for_all_gather(
                    ub_comm_dgrad,
                    x_local,
                    input_quantizer,
                    tensor_parallel_group,
                )

            # Perform dgrad GEMM
            dx, *_ = general_gemm(
                w,
                dy,
                get_workspace(),
                out_dtype=dtype,
                quantization_params=grad_input_quantizer,
                layout="NN",
                out=dx,
                use_split_accumulator=_2X_ACC_DGRAD,
                grad=True,
                ub=ub_comm_dgrad,
                ub_type=ub_type_dgrad,
                extra_output=dx_local if with_dgrad_reduce_scatter_dx else None,
                bulk_overlap=with_bulk_overlap,
            )
            if not (with_dgrad_reduce_scatter_dx or with_wgrad_reduce_scatter_dx):
                dx_local = dx
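            # dx holds the full-size grad input produced by the GEMM, while
            # dx_local is the per-rank shard: it is filled by the dgrad
            # reduce-scatter, aliases the Userbuffers local chunk when the
            # reduce-scatter is bulk-overlapped with the wgrad GEMM, or simply
            # aliases dx when no reduce-scatter is involved.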
        # wgrad GEMM
        dw = None
        if weight_requires_grad:

            # Initialize grad output
            if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer):
                # UB does not support overlapping grad output
                # all-gather with wgrad GEMM. Also, MXFP8 does not
                # allow reusing the grad output that was gathered for
                # the dgrad GEMM. We work around with blocking
                # all-gather for column-scaled MXFP8 data.
                grad_output_quantizer.set_usage(rowwise=False, columnwise=True)
                dy, _ = gather_along_first_dim(
                    grad_output,
                    tensor_parallel_group,
                    quantizer=grad_output_quantizer,
                )
            if tensor_parallel_mode == "column":
                dy = dy_local
            if dy is None:
                raise RuntimeError(
                    "wgrad GEMM requires grad output tensor, which has not been initialized"
                )
            if isinstance(dy, QuantizedTensorBase):
                dy.update_usage(rowwise_usage=False, columnwise_usage=True)

            # Initialize input tensor
            if tensor_parallel_mode == "row":
                x = x_local
            if x is None:
                raise RuntimeError(
                    "wgrad GEMM requires input tensor, which has not been initialized"
                )
            if isinstance(x, QuantizedTensorBase):
                x.update_usage(rowwise_usage=False, columnwise_usage=True)

            # Check grad weight tensor
            dw = grad_weight
            dw_dtype = dtype
            if dw is None:
                if accumulate_into_grad_weight:
                    raise ValueError(
                        "Attempted to accumulate into grad weight tensor "
                        "without providing grad weight tensor"
                    )
            else:
                dw_dtype = dw.dtype

            # Perform wgrad GEMM
            dw, *_ = general_gemm(
                x,
                dy,
                get_workspace(),
                out_dtype=dw_dtype,
                accumulate=accumulate_into_grad_weight,
                layout="NT",
                out=dw,
                use_split_accumulator=_2X_ACC_WGRAD,
                grad=True,
                ub=ub_comm_wgrad,
                ub_type=ub_type_wgrad,
                bulk_overlap=with_bulk_overlap,
            )

        # Bulk overlap reduce-scatter with non-FP8 buffer is
        # in-place. Need to copy grad input tensor to avoid data
        # corruption in Userbuffers buffer.
        if with_wgrad_reduce_scatter_dx:
            dx_local = dx_local.clone()

        # Compute grad bias if needed
        if db_async is not None:
            db_async.wait()
        if bias_requires_grad:
            extra_outputs["grad_bias"] = db

        return dx_local, dw, extra_outputs
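    # _functional_backward returns (grad_input, grad_weight, extra_outputs);
    # grad_weight is None when weight_requires_grad is False and
    # extra_outputs carries "grad_bias" when bias_requires_grad is set.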
    def fuser_backward(
        self,
@@ -633,40 +516,24 @@ class UserbuffersBackwardLinear(FusedOperation):
        else:
            accumulate_into_main_grad = False

        # Linear backward pass
        retval = UserbuffersBackwardLinear._functional_backward(
            grad_output=grad_output,
            input=x_local,
            weight=linear_op.weight,
            weight_requires_grad=linear_op_ctx.weight_requires_grad,
            bias_requires_grad=(bias_op is not None),
            dtype=linear_op_ctx.dtype,
            grad_weight=grad_weight,
            accumulate_into_grad_weight=accumulate_into_main_grad,
            tensor_parallel_mode=self.tensor_parallel_mode,
            tensor_parallel_group=self.tensor_parallel_group,
            sequence_parallel=self.sequence_parallel,
            with_quantized_compute=linear_op_ctx.with_quantized_compute,
            input_quantizer=linear_op_ctx.input_quantizer,
            weight_quantizer=linear_op_ctx.weight_quantizer,
            grad_output_quantizer=linear_op_ctx.grad_output_quantizer,
            grad_input_quantizer=None,  # Not supported
            ub_comm_name=linear_op._userbuffers_options["comm_name"],
        )
        grad_input, grad_weight, extra_outputs = retval
@@ -707,8 +574,6 @@ def fuse_userbuffers_backward_linear(
    """

    ### TODO Debug Userbuffers support

    # Return immediately if environment is not distributed
    if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
        return ops
...
@@ -4,20 +4,25 @@
"""Linear layer forward with Userbuffers communication."""

### TODO Debug Userbuffers support

from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Optional

import torch

from transformer_engine_torch import CommOverlapType

from ...cpp_extensions import general_gemm
from ...distributed import get_distributed_world_size
from ...fp8 import FP8GlobalStateManager
from ...module.base import (
    fill_userbuffers_buffer_for_all_gather,
    get_ub,
    get_workspace,
    _2X_ACC_FPROP,
)
from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
@@ -26,12 +31,6 @@ from ..op import (
    FusibleOperation,
    OperationContext,
)


class UserbuffersForwardLinear(FusedOperation):
@@ -51,9 +50,6 @@ class UserbuffersForwardLinear(FusedOperation):
        reduce_scatter: Optional[ReduceScatter],
    ) -> None:

        # Basic operations that comprise this fused operation
        op_idxs = {"linear": 0, "bias": None, "reduce_scatter": None}
        ops = [linear]
@@ -98,10 +94,10 @@ class UserbuffersForwardLinear(FusedOperation):
        tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        tensor_parallel_size: Optional[int] = None,
        sequence_parallel: bool = False,
        with_quantized_compute: bool = False,
        input_quantizer: Optional[Quantizer] = None,
        weight_quantizer: Optional[Quantizer] = None,
        output_quantizer: Optional[Quantizer] = None,
        ub_comm_name: str,
    ) -> tuple[torch.Tensor, dict]:
        """Functional API for forward pass
@@ -127,16 +123,14 @@ class UserbuffersForwardLinear(FusedOperation):
            parallelism, i.e. distributing input or output tensors
            along outer dimension (sequence or batch dim) when not
            distributing along inner dimension (embedding dim)
        with_quantized_compute: bool, default = `False`
            Whether to perform compute with quantized data.
        input_quantizer: Quantizer, optional
            Builder class for quantized input tensor.
        weight_quantizer: Quantizer, optional
            Builder class for quantized weight tensor.
        output_quantizer: Quantizer, optional
            Builder class for quantized output tensor.
        ub_comm_name: str
            Layer type (e.g. "qkv", "proj", "fc1", "fc2"). This is
            used to access the corresponding Userbuffers communicators
@@ -166,23 +160,6 @@ class UserbuffersForwardLinear(FusedOperation):
        if dtype not in (torch.float32, torch.float16, torch.bfloat16):
            raise ValueError(f"Supported dtypes are float32, float16, bfloat16 (got {dtype})")

        # Check tensor parallel group
        if tensor_parallel_size is None:
            tensor_parallel_size = get_distributed_world_size(tensor_parallel_group)
@@ -196,235 +173,106 @@ class UserbuffersForwardLinear(FusedOperation):
        if not sequence_parallel:
            raise RuntimeError(f"Invalid configuration for Userbuffers ({sequence_parallel=})")
        # Check quantizers
        if with_quantized_compute:
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
            if weight_quantizer is None:
                raise ValueError("Missing quantizer for weight tensor")
            if output_quantizer is not None:
                raise ValueError("FP8 output is not supported")
        else:
            input_quantizer = None
            weight_quantizer = None
            output_quantizer = None

        # Get Userbuffers communicator
        ub_comm = get_ub(ub_comm_name + "_fprop")
        with_ub_all_gather = tensor_parallel_mode == "column"
        with_ub_reduce_scatter = tensor_parallel_mode == "row"
        ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS

        # Initialize input tensor
        x_local = input
        x = None
        if with_ub_all_gather:
            if input_quantizer is not None:
                if not isinstance(x_local, QuantizedTensorBase):
                    input_quantizer.set_usage(rowwise=True, columnwise=True)
                    if isinstance(input_quantizer, Float8Quantizer):
                        input_quantizer.set_usage(columnwise=False)
                    x_local = input_quantizer(x_local)
                input_quantizer.set_usage(rowwise=True, columnwise=False)
            x, x_local = fill_userbuffers_buffer_for_all_gather(
                ub_comm,
                x_local,
                input_quantizer,
                tensor_parallel_group,
            )
        else:
            if with_quantized_compute:
                if not isinstance(x_local, QuantizedTensorBase):
                    input_quantizer.set_usage(rowwise=True, columnwise=True)
                    x_local = input_quantizer(x_local)
            else:
                if isinstance(x_local, QuantizedTensorBase):
                    x_local = x_local.dequantize(dtype=dtype)
                if x_local.dtype != dtype:
                    x_local = x_local.to(dtype=dtype)
            x = x_local

        # Initialize weight tensor
        w = weight
        w_is_quantized = isinstance(w, QuantizedTensorBase)
        if with_quantized_compute and not w_is_quantized:
            weight_quantizer.set_usage(rowwise=True)
            w = weight_quantizer(w)
        elif not with_quantized_compute and w_is_quantized:
            w = w.dequantize()
        if not with_quantized_compute and w.dtype != dtype:
            w = w.to(dtype=dtype)
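        # At this point x and w hold the GEMM operands (quantized when
        # quantized compute is enabled). For column-parallel layers x lives in
        # the Userbuffers buffer so the GEMM below overlaps the remaining
        # all-gather; for row-parallel layers the GEMM output is
        # reduce-scattered into the buffer allocated below.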
        # Construct output tensor if needed
        reduce_scatter_output = None
        if with_ub_reduce_scatter:
            y_local_size = list(x.size())
            y_local_size[0] //= tensor_parallel_size
            y_local_size[-1] = w.size(0)
            reduce_scatter_output = torch.empty(y_local_size, dtype=dtype, device=device)

        # Perform GEMM
        gemm_output, *_, reduce_scatter_output = general_gemm(
            w,
            x,
            get_workspace(),
            out_dtype=dtype,
            quantization_params=output_quantizer,
            bias=bias,
            use_split_accumulator=_2X_ACC_FPROP,
            ub=ub_comm,
            ub_type=ub_type,
            extra_output=reduce_scatter_output,
        )
        if with_ub_reduce_scatter:
            y_local = reduce_scatter_output
        else:
            y_local = gemm_output

        # Detach input tensor if needed
        # Note: PyTorch autograd produces esoteric errors if we save
        # input tensor as context for backward pass.
        if x_local is input:
            x_local = x_local.detach()

        # Configure input tensor for backward pass
        if with_quantized_compute and isinstance(x_local, QuantizedTensorBase):
            if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather):
                # FP8 does not support all-gather of transpose data
                x_local.update_usage(rowwise_usage=False, columnwise_usage=True)

        # Return cast tensors
        extra_outputs = {"input": x_local, "weight": w}
        return y_local, extra_outputs
     def fuser_forward(
         self,
@@ -450,23 +298,22 @@ class UserbuffersForwardLinear(FusedOperation):
             if basic_op_kwargs[idx]:
                 raise ValueError("Bias operation forward does not expect keyword arguments")

-        # FP8 metadata
-        with_fp8_compute = FP8GlobalStateManager.is_fp8_enabled()
-        input_fp8_meta = None
-        weight_fp8_meta = None
-        output_fp8_meta = None
-        grad_output_fp8_meta = None
-        grad_input_fp8_meta = None
-        if with_fp8_compute:
-            input_fp8_meta = linear_op.get_fp8_meta("input")
-            weight_fp8_meta = linear_op.get_fp8_meta("param")
-            next_op = basic_op_next_ops[-1]
-            if next_op is not None and next_op.num_fp8_scales("input") > 0:
-                output_fp8_meta = next_op.get_fp8_meta("input")
-            grad_output_fp8_meta = linear_op.get_fp8_meta("grad_output")
+        # Quantization metadata
+        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
+        input_quantizer = None
+        weight_quantizer = None
+        grad_output_quantizer = None
+        grad_input_quantizer = None
+        if with_quantized_compute:
+            recipe = FP8GlobalStateManager.get_fp8_recipe()
+            if not recipe.delayed() and not recipe.mxfp8():
+                raise RuntimeError("Userbuffers is only supported with FP8 delayed scaling recipe")
+            input_quantizer = linear_op.get_quantizer("forward", 0)
+            weight_quantizer = linear_op.get_quantizer("forward", 1)
+            grad_output_quantizer = linear_op.get_quantizer("backward", 0)
             prev_op = basic_op_prev_ops[0]
-            if prev_op is not None and prev_op.num_fp8_scales("grad_output") > 0:
-                grad_input_fp8_meta = prev_op.get_fp8_meta("grad_output")
+            if prev_op is not None and prev_op.num_quantizers("backward") > 0 and recipe.delayed():
+                grad_input_quantizer = prev_op.get_quantizer("backward", 0)

         # Get autocast dtype if needed
         dtype = None
@@ -482,26 +329,26 @@ class UserbuffersForwardLinear(FusedOperation):
             input=input_,
             weight=linear_op.weight,
             bias=bias,
-            device=linear_op.device,
             dtype=dtype,
             tensor_parallel_mode=self.tensor_parallel_mode,
             tensor_parallel_group=self.tensor_parallel_group,
             tensor_parallel_size=self.tensor_parallel_size,
             sequence_parallel=self.sequence_parallel,
-            with_fp8_compute=with_fp8_compute,
-            input_fp8_meta=input_fp8_meta,
-            weight_fp8_meta=weight_fp8_meta,
-            output_fp8_meta=output_fp8_meta,
+            with_quantized_compute=with_quantized_compute,
+            input_quantizer=input_quantizer,
+            weight_quantizer=weight_quantizer,
+            output_quantizer=None,  # Not supported
             ub_comm_name=linear_op._userbuffers_options["comm_name"],
         )
         x_local = extra_outputs["input"]

         # Save state for backward pass
         linear_op_ctx.save_for_backward(x_local)
-        linear_op_ctx.with_fp8_compute = with_fp8_compute
-        linear_op_ctx.weight_fp8_meta = weight_fp8_meta
-        linear_op_ctx.grad_output_fp8_meta = grad_output_fp8_meta
-        linear_op_ctx.grad_input_fp8_meta = grad_input_fp8_meta
+        linear_op_ctx.with_quantized_compute = with_quantized_compute
+        linear_op_ctx.input_quantizer = input_quantizer
+        linear_op_ctx.weight_quantizer = weight_quantizer
+        linear_op_ctx.grad_output_quantizer = grad_output_quantizer
+        linear_op_ctx.grad_input_quantizer = grad_input_quantizer
         linear_op_ctx.dtype = dtype
         linear_op_ctx.input_dims = input_.size()
         linear_op_ctx.input_requires_grad = input_.requires_grad
@@ -529,8 +376,6 @@ def fuse_userbuffers_forward_linear(
     """
-    return ops
+    ### TODO Debug Userbuffers support

     # Return immediately if environment is not distributed
     if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1:
         return ops
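The recipe check added in the hunk above means the fused Userbuffers path is only taken when FP8 compute is active with a supported scaling recipe. A hedged usage sketch of how that is typically enabled through the public API follows; the layer sizes and recipe arguments are illustrative assumptions, and any Transformer Engine layer run under fp8_autocast would exercise the same recipe query.

# Hedged sketch: enable FP8 delayed scaling so that recipe.delayed() is True.
# Module and recipe arguments below are illustrative, not prescribed values.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16)
layer = te.Linear(1024, 1024, bias=True).cuda()
x = torch.randn(32, 1024, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)  # runs with the delayed scaling recipe queried above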
@@ -55,7 +55,17 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
-        install_requires=["torch"],
+        setup_requires=[
+            "torch>=2.1",
+            "nvidia-cuda-runtime-cu12",
+            "nvidia-cublas-cu12",
+            "nvidia-cudnn-cu12",
+            "nvidia-cuda-cccl-cu12",
+            "nvidia-cuda-nvcc-cu12",
+            "nvidia-nvtx-cu12",
+            "nvidia-cuda-nvrtc-cu12",
+        ],
+        install_requires=["torch>=2.1"],
         tests_require=["numpy", "torchvision"],
     )
     if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
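The build-time dependencies moved into setup_requires are ordinary pip wheels, so their presence can be inspected before starting a build. The check below is purely illustrative and not part of the build system; the package list simply mirrors the names in the hunk above.

# Illustrative pre-build check: report which CUDA 12 wheels are already installed.
from importlib import metadata

cu12_wheels = [
    "nvidia-cuda-runtime-cu12",
    "nvidia-cublas-cu12",
    "nvidia-cudnn-cu12",
    "nvidia-cuda-cccl-cu12",
    "nvidia-cuda-nvcc-cu12",
    "nvidia-nvtx-cu12",
    "nvidia-cuda-nvrtc-cu12",
]

for name in cu12_wheels:
    try:
        print(f"{name}=={metadata.version(name)}")
    except metadata.PackageNotFoundError:
        print(f"{name} not installed; pip resolves it during the build")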