Commit c520cba3 authored by yuguo

[DCU] Preliminary adaptation

parent 5b6ef054
......@@ -17,6 +17,7 @@ import numpy as np
from packaging.version import Version as PkgVersion
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
......@@ -98,7 +99,7 @@ try:
except PackageNotFoundError:
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0)
and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
......@@ -128,7 +129,7 @@ else:
fa_utils.set_flash_attention_version()
elif (
torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0)
and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.warning(
......@@ -147,33 +148,34 @@ else:
# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
try:
fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (9, 0)
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
"flash-attn v3 is not installed. To use, please install it by \n%s",
fa_utils.v3_installation_steps,
if not IS_HIP_EXTENSION:
try:
fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (9, 0)
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
"flash-attn v3 is not installed. To use, please install it by \n%s",
fa_utils.v3_installation_steps,
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
)
from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
)
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
)
fa_utils.set_flash_attention_3_params()
from flashattn_hopper.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
from flashattn_hopper.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_forward as _flash_attn_varlen_fwd_v3,
)
from flashattn_hopper.flash_attn_interface import (
_flash_attn_varlen_backward as _flash_attn_varlen_bwd_v3,
)
fa_utils.set_flash_attention_3_params()
# Global vars for available attention backends and ALiBi cache
_attention_backends = {
......
......@@ -5,5 +5,7 @@
"""Python interface for c++ extensions"""
from transformer_engine_torch import *
from .fused_attn import *
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if not IS_HIP_EXTENSION:
from .fused_attn import *
from .gemm import *
......@@ -224,3 +224,85 @@ def general_grouped_gemm(
)
return out, bias, gelu_input
def general_batched_gemm(
A: List[torch.Tensor],
B: List[torch.Tensor],
out: List[torch.Tensor],
out_dtype: torch.dtype,
workspaces: List[torch.Tensor],
layout: str = "TN",
m_splits: Optional[List[int]] = None,
gelu: bool = False,
grad=False,
accumulate: bool = False,
bias: Optional[List[torch.Tensor]] = None,
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
single_output=False,
) -> Tuple[List[torch.Tensor], ...]:
"""
Batched GEMM over lists of A (weights) and B (inputs). Layout defaults to "TN";
inputs may be FP8 (Float8TensorBase) or regular high-precision tensors.
"""
num_gemms = len(A)
transa = layout[0] == "T"
transb = layout[1] == "T"
# assert [a.is_contiguous() for a in A]
# assert [b.is_contiguous() for b in B]
if isinstance(A[0], Float8TensorBase):
for a, b in zip(A, B):
assert_dim_for_fp8_exec(a._data)
assert_dim_for_fp8_exec(b._data)
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
# Use bfloat16 as default bias_dtype
gelu_input = empty_tensors
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype
sm_count = get_sm_count()
if grad and use_bias:
grad_bias = [
torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
]
else:
grad_bias = empty_tensors
bias = bias if use_bias else empty_tensors
if use_bias:
bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
else:
bias_dtype = TE_DType[torch.bfloat16]
if gelu:
gelu_input = [
torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
for o in out
] # this should differ with respect to single output
bias = tex.te_general_batched_gemm(
A,
transa,
B,
transb,
out,
out_dtype,
m_splits,
grad_bias if grad else bias,
bias_dtype,
single_output,
gelu_input, # this is pre_gelu_out
grad, # grad
workspaces,
workspaces[0].shape[0],
accumulate,
use_split_accumulator,
sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
)
return out, bias, gelu_input
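For orientation, a minimal call sketch of the wrapper above, modeled on the BatchedLinear forward call later in this commit; the shapes, the bf16 dtype, and the import paths are illustrative assumptions rather than part of the diff.

# Hypothetical sketch: num_gemms independent GEMMs out_i = x_i @ w_i^T (default "TN" layout),
# with all per-GEMM results written into one shared buffer via single_output=True.
import torch
from transformer_engine.pytorch.cpp_extensions import general_batched_gemm  # added by this commit
from transformer_engine.pytorch.module.base import (  # workspace helper added by this commit
    get_multi_stream_cublas_batchgemm_workspace,
)

num_gemms, in_features, out_features = 4, 256, 512
m_splits = [128] * num_gemms
weights = [torch.randn(out_features, in_features, dtype=torch.bfloat16, device="cuda")
           for _ in range(num_gemms)]
inputs = [torch.randn(m, in_features, dtype=torch.bfloat16, device="cuda") for m in m_splits]
out = torch.empty(sum(m_splits), out_features, dtype=torch.bfloat16, device="cuda")

general_batched_gemm(
    weights,
    inputs,
    [out],
    torch.bfloat16,
    get_multi_stream_cublas_batchgemm_workspace(),
    single_output=True,
    m_splits=m_splits,
)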
......@@ -14,11 +14,15 @@
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <cuda_runtime.h>
#ifndef USE_ROCM
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
......
......@@ -93,6 +93,16 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
#ifdef __HIP_PLATFORM_AMD__
std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
#endif
/***************************************************************************************************
* Transpose
**************************************************************************************************/
......
......@@ -16,11 +16,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_ERROR("get_fused_attn_backend is not supported on ROCm yet.");
return NVTE_No_Backend;
#else
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
head_dim_qk, head_dim_v, window_size_left, window_size_right);
return fused_attention_backend;
#endif
}
// fast zero-fills of tensors
......@@ -93,6 +98,10 @@ std::vector<py::object> fused_attn_fwd(
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_ERROR("fused_attn_fwd is not supported on ROCm yet.");
return {};
#else
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
TensorWrapper te_Q, te_K, te_V, te_O, te_S;
......@@ -254,6 +263,7 @@ std::vector<py::object> fused_attn_fwd(
// if training, [O, softmax-related tensors, rng_state]; if inference, [O]
return output_tensors;
#endif
}
// fused attention BWD with separate Q, K and V
......@@ -267,6 +277,10 @@ std::vector<py::object> fused_attn_bwd(
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer) {
#ifdef __HIP_PLATFORM_AMD__
NVTE_ERROR("fused_attn_bwd is not supported on ROCm yet.");
return {};
#else
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
auto none = py::none();
......@@ -492,6 +506,7 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {py_dQ, py_dK, py_dV, py::cast(dBias)};
#endif
}
namespace flash_attention {
......
......@@ -411,3 +411,124 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
math_sm_count, at::cuda::getCurrentCUDAStream());
return bias;
}
#ifdef USE_ROCM
std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count) {
using namespace transformer_engine;
using namespace transformer_engine::pytorch;
std::vector<NVTETensor> te_A_vector, te_B_vector, te_D_vector, te_bias_vector,
te_pre_gelu_out_vector, te_workspace_vector;
std::vector<TensorWrapper> wrappers;
std::vector<at::Tensor> D_vectors;
auto none = py::none();
std::vector<size_t> single_output_begins;
std::vector<size_t> single_output_ends;
int slicing_dim;
if (single_output && D == std::nullopt) {
NVTE_ERROR("not implemented, D should be allocated for single output case.");
}
void* output_data_ptr;
if (single_output) {
output_data_ptr = (*D)[0].data_ptr();
}
for (size_t i = 0; i < A.size(); i++) {
auto te_A = makeTransformerEngineTensor(A[i], none);
auto te_B = makeTransformerEngineTensor(B[i], none);
// if there is single output
at::Tensor out_tensor;
auto size_t_shape =
pytorch::detail::getGemmOutputShape(te_A.shape(), transa, te_B.shape(), transb);
bool D_numel_is_zero = false;
std::vector<int64_t> D_shape;
for (size_t t : size_t_shape) {
D_shape.push_back(t);
if (t == 0) {
D_numel_is_zero = true;
}
}
auto dtype = GetATenDType(D_type);
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
if (single_output) {
if (output_data_ptr == nullptr) {
out_tensor = at::empty(D_shape, opts);
} else {
// We need to check !D_numel_is_zero because if the final input portion has zero elements,
// output_data_ptr would point beyond the allocated memory of D. This would cause
// at::from_blob to fail as it would reference memory not allocated by CUDA.
if (!D_numel_is_zero) {
out_tensor = at::from_blob(output_data_ptr, D_shape, opts);
}
}
char* char_ptr = reinterpret_cast<char*>(output_data_ptr);
char_ptr += D_shape[0] * D_shape[1] * (*D)[0].element_size();
output_data_ptr = reinterpret_cast<void*>(char_ptr);
D_vectors.emplace_back(out_tensor);
} else {
if (D == std::nullopt) {
auto opts = torch::TensorOptions().dtype(dtype).device(torch::kCUDA);
out_tensor = at::empty(D_shape, opts);
D_vectors.emplace_back(out_tensor);
} else {
out_tensor = (*D)[i];
}
}
if (te_A.numel() == 0 || te_B.numel() == 0) {
if (out_tensor.numel() != 0 && !accumulate) out_tensor.zero_();
if (bias[i].numel() != 0 && grad) {
bias[i].zero_();
}
if (pre_gelu_out[i].numel() != 0) pre_gelu_out[i].zero_();
continue;
}
auto te_D = makeTransformerEngineTensor(out_tensor);
auto te_bias = makeTransformerEngineTensor(bias[i]);
auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]);
const auto gelu_shape = pre_gelu_out[i].data_ptr() == nullptr
? std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0))}
: std::vector<size_t>{static_cast<size_t>(te_pre_gelu_out.size(0)),
static_cast<size_t>(te_pre_gelu_out.size(1))};
DType gelu_type = bias_type;
te_pre_gelu_out =
makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type);
te_A_vector.emplace_back(te_A.data());
te_B_vector.emplace_back(te_B.data());
te_D_vector.emplace_back(te_D.data());
te_bias_vector.emplace_back(te_bias.data());
te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data());
wrappers.emplace_back(std::move(te_A));
wrappers.emplace_back(std::move(te_B));
wrappers.emplace_back(std::move(te_D));
wrappers.emplace_back(std::move(te_bias));
wrappers.emplace_back(std::move(te_pre_gelu_out));
}
for (size_t i = 0; i < workspace.size(); i++) {
auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), {workspaceSize}, DType::kByte);
te_workspace_vector.emplace_back(wsp.data());
wrappers.emplace_back(std::move(wsp));
}
// For now, we only have the multi-stream cublas backend.
nvte_multi_stream_cublas_batchgemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(),
                                   te_bias_vector.data(), te_pre_gelu_out_vector.data(),
                                   te_A_vector.size(), transa, transb, grad,
                                   te_workspace_vector.data(), accumulate, use_split_accumulator,
                                   math_sm_count, at::cuda::getCurrentCUDAStream());
return bias;
}
#endif
......@@ -6,6 +6,12 @@
#include "extensions.h"
#ifdef USE_ROCM
// ROCm builds have no cublasLt/cuDNN; report fixed placeholder versions.
size_t get_cublasLt_version() { return 10000000; }
size_t get_cudnn_version() { return 0; }
#else
size_t get_cublasLt_version() { return cublasLtGetVersion(); }
size_t get_cudnn_version() { return cudnnGetVersion(); }
#endif
......@@ -8,7 +8,11 @@
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#ifdef __HIP_PLATFORM_AMD__
#include "amd_detail/hip_float8.h"
#else
#include <cuda_fp8.h>
#endif
// Another possibility:
// #include <torch/all.h>
......@@ -28,8 +32,13 @@ typedef enum {
} adamMode_t;
using MATH_T = float;
#ifndef __HIP_PLATFORM_AMD__
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#else
using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
#endif
using transformer_engine::DType;
template <typename T>
......
......@@ -174,6 +174,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
#ifdef USE_ROCM
m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM"); /// rocblas
#endif
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
......@@ -207,6 +210,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
#ifdef USE_ROCM
m.attr("_num_cublas_batchgemm_streams") = py::int_(transformer_engine::num_batchgemm_streams);
#endif
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
......
......@@ -267,6 +267,8 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
constexpr uint32_t THREADS_PER_WARP = 32;
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T *x, T val, int lanes = 1,
......@@ -295,7 +297,11 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1,
// __SYNCWARP();
#pragma unroll
#ifdef __HIP_PLATFORM_AMD__
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i, THREADS_PER_WARP);
#else
for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);
#endif
}
if (share_result) {
......@@ -337,7 +343,11 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
#ifdef __HIP_PLATFORM_AMD__
final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
#else
final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
#endif
}
if (share_result) {
......
......@@ -37,7 +37,7 @@ from transformer_engine.pytorch.cpp_extensions.fused_attn import (
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.fp8 import get_fp8_te_dtype
from transformer_engine.pytorch.constants import TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import (
get_device_compute_capability,
......@@ -347,7 +347,7 @@ def get_attention_backend(
logger.debug("Disabling UnfusedDotProductAttention due to NVTE_UNFUSED_ATTN=0")
# Filter: Compute capability
if device_compute_capability < (8, 0):
if not IS_HIP_EXTENSION and device_compute_capability < (8, 0):
if use_flash_attention and FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as it requires compute capability sm80+")
use_flash_attention = False
......@@ -395,12 +395,22 @@ def get_attention_backend(
if use_unfused_attention:
logger.debug("Disabling UnfusedDotProductAttention as it does not support FP8")
use_unfused_attention = False
# TODO: the ROCm fused attention backend does not support FP8 yet
if IS_HIP_EXTENSION and use_fused_attention:
logger.debug("Disabling ROCm FusedAttention as it does not support FP8")
use_fused_attention = False
# Filter: Head dimension
if use_flash_attention and head_dim_qk != head_dim_v:
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
if not IS_HIP_EXTENSION:
if use_flash_attention and head_dim_qk != head_dim_v:
if FlashAttentionUtils.is_installed:
logger.debug("Disabling FlashAttention as it does not support MLA.")
use_flash_attention = False
else:
if use_fused_attention and head_dim_qk != head_dim_v:
logger.debug("Disabling FusedAttention as it does not support MLA in rocm backend.")
use_fused_attention = False
if use_flash_attention and (
head_dim_qk > 256
or head_dim_qk % 8 != 0
......@@ -441,6 +451,12 @@ def get_attention_backend(
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_flash_attention = False
if IS_HIP_EXTENSION and use_fused_attention and pad_between_seqs:
logger.debug(
"Disabling rocm fused attn for qkv_format = thd when there is "
"padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
)
use_fused_attention = False
# Filter: Dropout
if attention_dropout != 0.0 and use_flash_attention and FlashAttentionUtils.use_v3:
......@@ -839,7 +855,7 @@ def get_attention_backend(
# Select FusedAttention for performance
if (
use_flash_attention
use_flash_attention and (not IS_HIP_EXTENSION)
and use_fused_attention
and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
):
......@@ -852,6 +868,7 @@ def get_attention_backend(
if (
use_flash_attention
and use_fused_attention
and not IS_HIP_EXTENSION
and fused_attention_backend == FusedAttnBackend["FP8"]
and FlashAttentionUtils.use_v3
):
......
......@@ -24,6 +24,7 @@ from transformer_engine.common.recipe import (
from .constants import dist_group_type
from .utils import get_device_compute_capability
from .jit import jit_fuser
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["fp8_autocast", "fp8_model_init"]
......@@ -31,14 +32,20 @@ __all__ = ["fp8_autocast", "fp8_model_init"]
def check_fp8_support() -> Tuple[bool, str]:
"""Return if fp8 support is available"""
if get_device_compute_capability() >= (9, 0): # hopper and above
return True, ""
if get_device_compute_capability() < (8, 9): # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if tex.get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if float(torch.version.cuda) < 12.1:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
if IS_HIP_EXTENSION:
if get_device_compute_capability() == (9, 4):
return True, ""
else:
return False, "DCU not support fp8 for now"
else:
if get_device_compute_capability() >= (9, 0): # hopper and above
return True, ""
if get_device_compute_capability() < (8, 9): # pre-ada
return False, "Device compute capability 8.9 or higher required for FP8 execution."
if tex.get_cublasLt_version() < 120103:
return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if float(torch.version.cuda) < 12.1:
return False, "Cuda version 12.1 or higher required for FP8 execution on Ada."
return True, ""
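The tuple contract is unchanged on both paths: a boolean plus a human-readable reason. A hedged sketch of how a caller might consume it (variable names are illustrative):

# Hypothetical caller-side check before enabling FP8 execution.
fp8_available, reason = check_fp8_support()
if not fp8_available:
    print(f"FP8 disabled: {reason}")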
......
......@@ -7,6 +7,7 @@ import os
from typing import Callable, Optional, Tuple
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
# pylint: disable=unnecessary-lambda-assignment
......@@ -27,27 +28,28 @@ no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recu
def set_jit_fusion_options() -> None:
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
pass
elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
if not IS_HIP_EXTENSION:
"""Set PyTorch JIT layer fusion options."""
# flags required to enable jit fusion kernels
TORCH_MAJOR = int(torch.__version__.split(".")[0])
TORCH_MINOR = int(torch.__version__.split(".")[1])
if TORCH_MAJOR == 2 and TORCH_MINOR >= 2:
pass
elif (TORCH_MAJOR == 2) or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10):
# nvfuser
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(True)
torch._C._debug_set_autodiff_subgraph_inlining(False)
else:
# legacy pytorch fuser
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)
@jit_fuser
......
......@@ -35,6 +35,7 @@ from ..constants import dist_group_type
from ..tensor import QuantizedTensor, Quantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["initialize_ub", "destroy_ub"]
......@@ -42,6 +43,7 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
_multi_stream_cublas_workspace = []
_multi_stream_cublas_batchgemm_workspace = []
_cublas_workspace = None
_ub_communicators = None
_NUM_MAX_UB_STREAMS = 3
......@@ -51,6 +53,13 @@ layers_atomic_ring_exchange = []
def get_cublas_workspace_size_bytes() -> None:
"""Return 32 MiB if using hopper, 4 MiB for all other architectures."""
# The NVTE_BLASLT_NOPAD env var controls the hipBLASLt workspace padding on ROCm
if IS_HIP_EXTENSION:
nvte_blaslt_nopad = int(os.environ.get("NVTE_BLASLT_NOPAD", 0))
if nvte_blaslt_nopad:
return 536_870_912
else:
return 1_073_741_824
if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
return 33_554_432
return 4_194_304
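On ROCm the workspace size is now switchable through NVTE_BLASLT_NOPAD; a hedged sketch of selecting the smaller workspace (the variable must be set before the first workspace is allocated):

import os

# Hypothetical: request the 512 MiB hipBLASLt workspace instead of the default 1 GiB.
os.environ["NVTE_BLASLT_NOPAD"] = "1"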
......@@ -76,6 +85,16 @@ def get_multi_stream_cublas_workspace() -> List[torch.Tensor]:
)
return _multi_stream_cublas_workspace
def get_multi_stream_cublas_batchgemm_workspace() -> List[torch.Tensor]:
"""Returns workspace for multi-stream cublas."""
global _multi_stream_cublas_batchgemm_workspace
if not _multi_stream_cublas_batchgemm_workspace:
for _ in range(tex._num_cublas_batchgemm_streams):
_multi_stream_cublas_batchgemm_workspace.append(
torch.empty(get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda")
)
return _multi_stream_cublas_batchgemm_workspace
def initialize_ub(
shape: list,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""BatchedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import torch
import transformer_engine_torch as tex
from .base import (
get_multi_stream_cublas_batchgemm_workspace,
TransformerEngineBaseModule,
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
)
from ..fp8 import FP8GlobalStateManager
from ..utils import (
divide,
cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data,
init_method_constant,
requires_grad,
)
from ..distributed import (
set_tensor_model_parallel_attributes,
get_distributed_world_size,
is_fp8_activation_recompute_enabled,
in_fp8_activation_recompute_phase,
)
from ..cpp_extensions import (
general_batched_gemm,
)
from ..constants import GemmParallelModes, dist_group_type, TE_DType
from ..jit import no_torch_dynamo
from ..graph import is_graph_capturing
from ..tensor.float8_tensor import Float8Tensor
from ..cpu_offload import is_cpu_offload_enabled
from ..tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
prepare_for_saving,
restore_from_saved,
)
__all__ = ["BatchedLinear"]
class _BatchedLinear(torch.autograd.Function):
"""BatchedLinear semi-top level module
Calls custom cuda extensions.
"""
@staticmethod
def forward(
ctx,
inp: torch.Tensor,
m_splits: List[int],
use_bias: bool,
is_first_microbatch: Union[bool, None],
fp8: bool,
fp8_calibration: bool,
input_quantizers: List[Quantizer],
weight_quantizers: List[Quantizer],
output_quantizers: List[Quantizer],
grad_output_quantizers: List[Quantizer],
fuse_wgrad_accumulation: bool,
cpu_offloading: bool,
sequence_parallel: bool,
activation_dtype: torch.dtype,
is_grad_enabled: bool,
module,
skip_fp8_weight_update,
*weights_and_biases,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
num_gemms = len(m_splits)
weights = weights_and_biases[:num_gemms]
biases = weights_and_biases[num_gemms:]
device = inp.device
# TODO Support MXFP8 # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("BatchedLinear does not yet support MXFP8")
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_current_scaling():
raise NotImplementedError("BatchedLinear does not yet support Float8 Current Scaling")
# Make sure input dimensions are compatible
in_features = weights[0].shape[-1]
assert inp.shape[-1] == in_features, "GEMM not possible"
inputmats = torch.split(inp.view(-1, in_features), m_splits)
if fp8:
assert_dim_for_fp8_exec(*inputmats, *weights)
# Cast input to expected dtype
inputmats_no_fp8 = [cast_if_needed(mat, activation_dtype) for mat in inputmats]
inputmats = []
weight_requires_grad = weights[0].requires_grad
if input_quantizers[0] is not None:
for input_quantizer in input_quantizers:
input_quantizer.set_usage(
rowwise=True,
columnwise=(is_grad_enabled and weight_requires_grad),
)
columnwise_usage = is_grad_enabled and inp.requires_grad
if not columnwise_usage:
columnwise_usage = (
is_fp8_activation_recompute_enabled()
and not in_fp8_activation_recompute_phase()
)
if weight_quantizers[0] is not None:
for weight_quantizer in weight_quantizers:
weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
if output_quantizers[0] is not None:
for output_quantizer in output_quantizers:
output_quantizer.set_usage(rowwise=True, columnwise=False)
if fp8:
inputmats = tex.fused_multi_quantize(
inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
)
weights_fp8 = []
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
if not isinstance(weights[0], QuantizedTensor):
# FP8 cast to workspace buffer
update_workspace = is_first_microbatch is None or is_first_microbatch
for i in range(num_gemms):
weight_fp8 = module.get_weight_workspace(
tensor=weights[i],
quantizer=weight_quantizers[i],
cache_name=(None if is_first_microbatch is None else f"weight{i}"),
update_workspace=update_workspace,
skip_update_flag=skip_fp8_weight_update,
)
weights_fp8.append(weight_fp8)
else:
weights_fp8 = weights
else:
inputmats = inputmats_no_fp8
bias_dtype = activation_dtype
weights_fp8 = [cast_if_needed(weight, activation_dtype) for weight in weights]
biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases
out = torch.empty(
[sum(m_splits), weights_fp8[0].size(0)],
dtype=activation_dtype,
device=device,
)
_ = general_batched_gemm(
weights_fp8,
inputmats,
[out],
activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
single_output=True,
m_splits=m_splits,
bias=biases,
use_bias=use_bias,
use_split_accumulator=_2X_ACC_FPROP,
)
if fp8_calibration:
# amax of input
for i in range(num_gemms):
input_quantizers[i].calibrate(inputmats[i])
for i in range(num_gemms):
weight_quantizers[i].calibrate(weights[i])
if is_grad_enabled:
ctx.weights_shape_1 = weights[0].shape[1]
tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
ctx.save_for_backward(*tensors_to_save)
ctx.tensor_objects = tensor_objects
ctx.weights_requires_grad = weights[0].requires_grad
if fuse_wgrad_accumulation and ctx.weights_requires_grad:
ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
else:
ctx.main_grads = [None] * num_gemms
ctx.device = device
ctx.grad_output_quantizers = grad_output_quantizers
ctx.m_splits = m_splits
ctx.num_gemms = num_gemms
ctx.activation_dtype = activation_dtype
ctx.fp8 = fp8
ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
ctx.cpu_offloading = cpu_offloading
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.inp_shape = inp.shape
ctx.requires_dgrad = inp.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
if ctx.fp8 and requires_grad(inp, weights[0], biases[0]):
ctx.reduce_and_update_bwd_fp8_tensors = (
ctx.reduce_and_update_bwd_fp8_tensors
or FP8GlobalStateManager.is_first_fp8_module()
)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
return out.view(-1, *inp.shape[1:-1], out.shape[-1])
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
with torch.cuda.nvtx.range("_BatchedLinear_backward"):
saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
N = ctx.num_gemms
inputmats = saved_tensors[:N]
weights = saved_tensors[N : 2 * N]
biases = saved_tensors[2 * N : 3 * N]
main_grads = ctx.main_grads
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:  # TODO
for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i]
weights[i] = w
# preprocess grad_output
grad_output = grad_output.contiguous()
grad_output_mats = torch.split(
grad_output.view(-1, grad_output.shape[-1]), ctx.m_splits
)
grad_output = [None] * ctx.num_gemms
grad_biases = [None] * ctx.num_gemms
if ctx.fp8:
if ctx.use_bias:
for i in range(ctx.num_gemms):
grad_biases[i], grad_output[i] = tex.bgrad_quantize(
grad_output_mats[i], ctx.grad_output_quantizers[i]
)
else:
grad_output = tex.fused_multi_quantize(
grad_output_mats,
None,
ctx.grad_output_quantizers,
TE_DType[ctx.activation_dtype],
)
else:
grad_output = grad_output_mats
if ctx.is_first_microbatch is not None:
accumulate_wgrad_into_param_main_grad = (
ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch
)
else:
accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation
if ctx.requires_dgrad:
dgrad = torch.empty(
(sum(ctx.m_splits), ctx.weights_shape_1),
dtype=ctx.activation_dtype,
device=ctx.device,
)
general_batched_gemm(
weights,
grad_output,
[dgrad],
ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
single_output=True,
layout="NN",
m_splits=ctx.m_splits,
grad=True,
use_split_accumulator=_2X_ACC_DGRAD,
)
if ctx.weights_requires_grad:
if ctx.fuse_wgrad_accumulation:
wgrad_list = main_grads
else:
wgrad_list = [
torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device)
for w in weights
]
# WGRAD
_, grad_biases_, _ = general_batched_gemm(
inputmats,
grad_output,
wgrad_list,
ctx.activation_dtype,
get_multi_stream_cublas_batchgemm_workspace(),
layout="NT",
grad=True,
m_splits=ctx.m_splits,
use_bias=ctx.use_bias if grad_biases[0] is None else None,
bias=biases,
use_split_accumulator=_2X_ACC_WGRAD,
accumulate=accumulate_wgrad_into_param_main_grad,
)
for i in range(ctx.num_gemms):
if grad_biases[i] is None:
grad_biases[i] = grad_biases_[i]
del grad_biases_
# Deallocate input tensor
clear_tensor_data(*inputmats)
def handle_custom_ddp_from_mcore(w, wgrad):
if ctx.weights_requires_grad:
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
wgrad = torch.empty(
w.main_grad.shape,
dtype=w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
elif ctx.fuse_wgrad_accumulation:
wgrad = None
else:
wgrad = None
return wgrad
wgrad_list = [
handle_custom_ddp_from_mcore(w, wgrad) for w, wgrad in zip(weights, wgrad_list)
]
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias:
grad_biases = [None] * ctx.num_gemms
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
return (
dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,  # module
None,  # skip_fp8_weight_update
*wgrad_list,
*grad_biases,
)
class BatchedLinear(TransformerEngineBaseModule):
"""Applies linear transformations to the incoming data list
:math:`y_i = x_iA_i^T + b_i` in a batched way.
Parameters
----------
num_gemms : int
number of GEMMs to be performed simultaneously.
in_features : int
size of each input sample.
out_features : int
size of each output sample.
bias : bool, default = `True`
if set to `False`, the layer will not learn an additive bias.
init_method : Callable, default = `None`
used for initializing weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
get_rng_state_tracker : Callable, default = `None`
used to get the random number generator state tracker for initializing weights.
rng_tracker_name : str, default = `None`
the param passed to get_rng_state_tracker to get the specific rng tracker.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
forward pass.
Optimization parameters
-----------------------
fuse_wgrad_accumulation : bool, default = 'False'
if set to `True`, enables fusing of creation and accumulation of
the weight gradient. When enabled, it is assumed that the weights
have an additional `main_grad` attribute (used instead of the
regular `grad`) which is a pre-allocated buffer of the correct
size to accumulate gradients in.
return_bias : bool, default = `False`
when set to `True`, this module will not apply the additive bias itself, but
instead return the bias value during the forward pass together with the
output of the linear transformation :math:`y = xA^T`. This is useful when
the bias addition can be fused to subsequent operations.
params_dtype : torch.dtype, default = `torch.get_default_dtype()`
it controls the type used to allocate the initial parameters. Useful when
the model is trained with lower precision and the original FP32 parameters
would not fit in GPU memory.
"""
def __init__(
self,
num_gemms: int,
in_features: int,
out_features: int,
sequence_parallel: bool = False,
fuse_wgrad_accumulation: bool = False,
tp_group: Optional[dist_group_type] = None,
tp_size: int = 1,
get_rng_state_tracker: Optional[Callable] = None,
rng_tracker_name: Optional[str] = None,
init_method: Optional[Callable] = None,
bias: bool = True,
return_bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
parallel_mode: Optional[str] = None,
device: Union[torch.device, str] = "cuda",
ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False,
ub_name: Optional[str] = None,
) -> None:
super().__init__()
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms
self.in_features = in_features
self.out_features = out_features
self.fuse_wgrad_accumulation = fuse_wgrad_accumulation
self.use_bias = bias
self.return_bias = return_bias
self.apply_bias = bias and not return_bias
self.ub_overlap_rs = ub_overlap_rs
self.ub_overlap_ag = ub_overlap_ag
self.ub_name = ub_name
assert (
not ub_overlap_rs and not ub_overlap_ag
), "BatchedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self._offsets = {"input": 0, "weight": num_gemms, "output": 2 * num_gemms, "grad_output": 0}
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
self.set_tensor_parallel_group(tp_group)
else:
self.tp_size = get_distributed_world_size(tp_group)
self.set_tensor_parallel_group(tp_group)
self.set_nccl_overlap_warning_if_tp()
self.parallel_mode = parallel_mode
assert (
self.parallel_mode in GemmParallelModes
), f"parallel_mode {parallel_mode} not supported"
if self.parallel_mode == "column":
self.out_features = divide(self.out_features, self.tp_size)
elif self.parallel_mode == "row":
self.in_features = divide(self.in_features, self.tp_size)
self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
for i in range(self.num_gemms):
# Construct weight parameter
self.register_parameter(
f"weight{i}",
torch.nn.Parameter(
torch.empty(
self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
),
),
init_fn=init_method,
get_rng_state_tracker=get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i,
)
# Construct bias parameters if needed
if self.use_bias:
self.register_parameter(
f"bias{i}",
torch.nn.Parameter(
torch.empty(
self.out_features,
device=device,
dtype=params_dtype,
),
),
init_fn=init_method_constant(0.0),
)
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
setattr(self, f"bias{i}", bias)
if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)
self.reset_parameters(defer_init=device == "meta")
# For RPL, bias has to be added after TP collectives
# So it cannot be fused with the GEMM
if self.parallel_mode == "row" and self.apply_bias:
self.gemm_bias_unfused_add = True
else:
self.gemm_bias_unfused_add = False
def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
if not defer_init:
# Set parallelism attributes for linear weights
for i in range(self.num_gemms):
set_tensor_model_parallel_attributes(
tensor=getattr(self, f"weight{i}"),
is_parallel=True,
dim=1 if self.parallel_mode == "row" else 0,
stride=1,
)
# Set parallelism attributes for linear biases
if self.use_bias:
for i in range(self.num_gemms):
if self.parallel_mode == "row":
setattr(
getattr(self, f"bias{i}"),
"sequence_parallel",
self.sequence_parallel,
)
elif self.parallel_mode == "column":
set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1)
@no_torch_dynamo()
def forward(
self,
inp: torch.Tensor,
m_splits: List[int],
is_first_microbatch: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
"""
Apply the linear transformation to the input.
Parameters
----------
inp : torch.Tensor
Input tensor.
m_splits : List[int]
List of integers representing the split of the input tensor.
is_first_microbatch : {True, False, None}, default = None
During training using either gradient accumulation or
pipeline parallelism a minibatch of data is further split
into microbatches. Between the microbatches of the same minibatch
the model weights are not updated. Setting this parameter indicates
whether the current microbatch is the first in a minibatch or not.
When set, this parameter enables additional optimizations:
* during FP8 training, it allows caching of the FP8 versions of
the weights
* it also allows skipping gradient accumulation during the
first microbatch (since it is the first gradient being
produced)
"""
assert not isinstance(
inp, Float8Tensor
), "BatchedLinear doesn't support input tensor in FP8."
assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
if skip_fp8_weight_update is not None:
is_first_microbatch = False
with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
if not self.fp8:
weight_tensors = [
w.dequantize() if isinstance(w, QuantizedTensor) else w for w in weight_tensors
]
input_quantizers, weight_quantizers, output_quantizers = (
[None] * self.num_gemms,
[None] * self.num_gemms,
[None] * self.num_gemms,
)
grad_output_quantizers, _ = [None] * self.num_gemms, [None] * self.num_gemms
if self.fp8:
input_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["input"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
input_quantizers[i].internal = True
weight_quantizers = [
self.quantizers["scaling_fwd"][self._offsets["weight"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
weight_quantizers[i].internal = True
if torch.is_grad_enabled():
grad_output_quantizers = [
self.quantizers["scaling_bwd"][self._offsets["input"] + i]
for i in range(self.num_gemms)
]
for i in range(self.num_gemms):
grad_output_quantizers[i].internal = True
if torch.is_grad_enabled():
linear_fn = _BatchedLinear.apply
args = []
else:
linear_fn = _BatchedLinear.forward
args = [None]
args += (
inp,
m_splits,
self.apply_bias and not self.gemm_bias_unfused_add,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
torch.is_grad_enabled(),
self,
skip_fp8_weight_update,
*weight_tensors,
*bias_tensors,
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
out_shape = out.shape
out = torch.cat(
[
o + cast_if_needed(b, self.activation_dtype)
for o, b in zip(
torch.split(out.view(-1, self.out_features), m_splits), bias_tensors
)
]
).view(out_shape)
if self.return_bias:
return out, [cast_if_needed(b, self.activation_dtype) for b in bias_tensors]
return out
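A hedged usage sketch of the new BatchedLinear module; the import path, shapes, and split sizes are illustrative assumptions, not taken from the commit.

import torch
# Hypothetical import path; the class is defined in a new module file added by this commit.
from transformer_engine.pytorch.module.batched_linear import BatchedLinear

num_gemms, in_features, out_features = 3, 256, 512
layer = BatchedLinear(num_gemms, in_features, out_features, bias=True, device="cuda")

m_splits = [128, 64, 192]  # rows of the flattened input handled by each GEMM
inp = torch.randn(sum(m_splits), in_features, device="cuda")
out = layer(inp, m_splits)  # -> [sum(m_splits), out_features]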
......@@ -12,6 +12,7 @@ from operator import mul as multiply_op
import torch
from torch.nn.parameter import Parameter
from torch.nn import init
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
......@@ -1454,7 +1455,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
fc2_weight = fc2_weight.from_float8()
# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
if (not IS_HIP_EXTENSION
and self.bias_gelu_nvfusion and not use_reentrant_activation_recompute()):
self.bias_gelu_nvfusion = False
if torch.is_grad_enabled():
......
......@@ -11,7 +11,14 @@ import triton
import triton.language as tl
from transformer_engine_torch import DType as TE_DType
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if IS_HIP_EXTENSION:
e5m2_data_type = tl.float8e5b16
e4m3_data_type = tl.float8e4b8
else:
e5m2_data_type = tl.float8e5
e4m3_data_type = tl.float8e4nv
@triton.jit
def _row_id_map_pass_1_kernel(
......
......@@ -13,7 +13,7 @@ import torch
import transformer_engine.pytorch.cpp_extensions as ext
from .tensor.quantized_tensor import QuantizedTensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
"""Check if any of the given tensors require gradient."""
......@@ -242,12 +242,34 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
f"but got tensor with dims={list(tensor.size())}"
)
if IS_HIP_EXTENSION:
import re

def is_mi200():
"""Check whether the current device is an MI200-series GPU (MI200/MI210/MI250)."""
return re.search('AMD Instinct MI2.0', torch.cuda.get_device_name(torch.cuda.current_device())) is not None

def is_K100_AI():
"""Check whether the current device is a K100_AI."""
return re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None

def is_BW3000():
"""Check whether the current device is a BW3000."""
return re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None
def is_bf16_compatible() -> None:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
check on device compute capability to enforce sm_80 or higher.
"""
return torch.cuda.get_device_capability()[0] >= 8
if IS_HIP_EXTENSION:
# bf16 is supported on MI300-class (gfx942), MI200, K100_AI, and BW3000 devices
if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW3000():
return True
else:
return False
else:
return torch.cuda.get_device_capability()[0] >= 8
def non_tn_fp8_gemm_supported() -> bool:
......@@ -260,6 +282,9 @@ def non_tn_fp8_gemm_supported() -> bool:
@functools.lru_cache(maxsize=None)
def get_cudnn_version() -> Tuple[int, int, int]:
"""Runtime cuDNN version (major, minor, patch)"""
# ROCm fused attention does not use cuDNN; report a high version so version-based test filters do not skip it
if IS_HIP_EXTENSION:
return (99, 0, 0)
encoded_version = ext.get_cudnn_version()
major_version_magnitude = 1000 if encoded_version < 90000 else 10000
major, encoded_version = divmod(encoded_version, major_version_magnitude)
......