Commit c520cba3 authored by yuguo

[DCU] Preliminary adaptation

parent 5b6ef054
......@@ -17,6 +17,7 @@ import numpy as np
from packaging.version import Version as PkgVersion
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine.pytorch.utils import (
......@@ -98,7 +99,7 @@ try:
except PackageNotFoundError:
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0)
and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.debug(
......@@ -128,7 +129,7 @@ else:
fa_utils.set_flash_attention_version()
elif (
torch.cuda.is_available()
and get_device_compute_capability() >= (8, 0)
and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
and dpa_utils._NVTE_FLASH_ATTN
):
attn_log.fa_logger.warning(
......@@ -147,9 +148,10 @@ else:
# Detect flash-attn v3 in the environment
# This section will be removed when FA3 is released as a regular FA package,
# i.e. flashattn-hopper 3.0.0 as flash-attn 3.0.0
try:
if not IS_HIP_EXTENSION:
try:
fa_utils.fa3_version = PkgVersion(get_pkg_version("flashattn-hopper"))
except PackageNotFoundError:
except PackageNotFoundError:
if (
torch.cuda.is_available()
and get_device_compute_capability() >= (9, 0)
......@@ -159,7 +161,7 @@ except PackageNotFoundError:
"flash-attn v3 is not installed. To use, please install it by \n%s",
fa_utils.v3_installation_steps,
)
else:
else:
from flashattn_hopper.flash_attn_interface import flash_attn_func as flash_attn_func_v3
from flashattn_hopper.flash_attn_interface import (
flash_attn_varlen_func as flash_attn_varlen_func_v3,
......
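For context, a minimal sketch of the gating intent behind the changes above (illustrative only, not part of this commit): on ROCm/HIP builds there is no NVIDIA compute capability to compare against, so IS_HIP_EXTENSION bypasses the SM 8.0 requirement for the flash-attn hints, while the flashattn-hopper (FA3) probe is skipped entirely because it targets Hopper GPUs.

import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import get_device_compute_capability

def flash_attn_device_ok() -> bool:
    # ROCm devices bypass the SM >= 8.0 check applied to CUDA devices.
    return torch.cuda.is_available() and (
        IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0)
    )

def should_probe_fa3() -> bool:
    # flashattn-hopper is Hopper-specific; never probe for it on HIP builds.
    return not IS_HIP_EXTENSION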
......@@ -5,5 +5,7 @@
"""Python interface for c++ extensions"""
from transformer_engine_torch import *
from .fused_attn import *
from torch.utils.cpp_extension import IS_HIP_EXTENSION
if not IS_HIP_EXTENSION:
    from .fused_attn import *
from .gemm import *
......@@ -224,3 +224,85 @@ def general_grouped_gemm(
    )
    return out, bias, gelu_input


def general_batched_gemm(
    A: List[torch.Tensor],
    B: List[torch.Tensor],
    out: List[torch.Tensor],
    out_dtype: torch.dtype,
    workspaces: List[torch.Tensor],
    layout: str = "TN",
    m_splits: Optional[List[int]] = None,
    gelu: bool = False,
    grad=False,
    accumulate: bool = False,
    bias: Optional[List[torch.Tensor]] = None,
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
    single_output=False,
) -> Tuple[List[torch.Tensor], ...]:
    """
    TN layout Batched GEMM with fp8 inputs.
    """
    num_gemms = len(A)
    transa = layout[0] == "T"
    transb = layout[1] == "T"
    # assert [a.is_contiguous() for a in A]
    # assert [b.is_contiguous() for b in B]
    if isinstance(A[0], Float8TensorBase):
        for a, b in zip(A, B):
            assert_dim_for_fp8_exec(a._data)
            assert_dim_for_fp8_exec(b._data)

    empty_tensor = _empty_tensor()
    empty_tensors = [empty_tensor] * num_gemms

    # Use bfloat16 as default bias_dtype
    gelu_input = empty_tensors
    out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

    sm_count = get_sm_count()
    if grad and use_bias:
        grad_bias = [
            torch.empty(B[i].shape[1], dtype=out[0].dtype, device="cuda") for i in range(num_gemms)
        ]
    else:
        grad_bias = empty_tensors
    bias = bias if use_bias else empty_tensors
    if use_bias:
        bias_dtype = TE_DType[grad_bias[0].dtype] if grad else TE_DType[bias[0].dtype]
    else:
        bias_dtype = TE_DType[torch.bfloat16]

    if gelu:
        gelu_input = [
            torch.empty_like(o, dtype=bias_dtype, memory_format=torch.contiguous_format)
            for o in out
        ]  # this should differ with respect to single output

    bias = tex.te_general_batched_gemm(
        A,
        transa,
        B,
        transb,
        out,
        out_dtype,
        m_splits,
        grad_bias if grad else bias,
        bias_dtype,
        single_output,
        gelu_input,  # this is pre_gelu_out
        grad,  # grad
        workspaces,
        workspaces[0].shape[0],
        accumulate,
        use_split_accumulator,
        sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))),
    )

    return out, bias, gelu_input
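Below is a rough, hypothetical call sketch for the new wrapper (tensor shapes, the m_splits values, and the workspace size are illustrative guesses; the exact argument semantics depend on the te_general_batched_gemm backend added by this commit):

import torch
from transformer_engine.pytorch.cpp_extensions.gemm import general_batched_gemm

num_gemms, m, n, k = 4, 128, 256, 64
# "TN" layout: A supplies the transposed operand, B the non-transposed one (assumed).
A = [torch.randn(n, k, dtype=torch.bfloat16, device="cuda") for _ in range(num_gemms)]
B = [torch.randn(m, k, dtype=torch.bfloat16, device="cuda") for _ in range(num_gemms)]
out = [torch.empty(m, n, dtype=torch.bfloat16, device="cuda") for _ in range(num_gemms)]
workspaces = [torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")]

out, bias, gelu_input = general_batched_gemm(
    A, B, out, torch.bfloat16, workspaces, layout="TN", m_splits=[m] * num_gemms
)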
......@@ -14,11 +14,15 @@
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <cuda_runtime.h>
#ifndef USE_ROCM
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
......
......@@ -93,6 +93,16 @@ std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
bool use_split_accumulator, int math_sm_count);
#ifdef __HIP_PLATFORM_AMD__
std::optional<std::vector<at::Tensor>> te_general_batched_gemm(
    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, int math_sm_count);
#endif
/***************************************************************************************************
* Transpose
**************************************************************************************************/
......
......@@ -16,11 +16,16 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend(
float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("get_fused_attn_backend is not supported on ROCm yet.");
  return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
#else
  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
      head_dim_qk, head_dim_v, window_size_left, window_size_right);
  return fused_attention_backend;
#endif
}
// fast zero-fills of tensors
......@@ -93,6 +98,10 @@ std::vector<py::object> fused_attn_fwd(
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("fused_attn_fwd is not supported on ROCm yet.");
  return {};
#else
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
  TensorWrapper te_Q, te_K, te_V, te_O, te_S;
......@@ -254,6 +263,7 @@ std::vector<py::object> fused_attn_fwd(
// if training, [O, softmax-related tensors, rng_state]; if inference, [O]
return output_tensors;
#endif
}
// fused attention BWD with separate Q, K and V
......@@ -267,6 +277,10 @@ std::vector<py::object> fused_attn_bwd(
const c10::optional<at::Tensor> cu_seqlens_q_padded,
const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
py::handle dp_quantizer, py::handle dqkv_quantizer) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_ERROR("fused_attn_bwd is not supported on ROCm yet.");
  return {};
#else
  using namespace transformer_engine;
  using namespace transformer_engine::pytorch;
  auto none = py::none();
......@@ -492,6 +506,7 @@ std::vector<py::object> fused_attn_bwd(
nvte_tensor_pack_destroy(&nvte_aux_tensor_pack);
return {py_dQ, py_dK, py_dV, py::cast(dBias)};
#endif
}
namespace flash_attention {
......
......@@ -6,6 +6,12 @@
#include "extensions.h"
#ifdef USE_ROCM
// cublasLt and cuDNN are unavailable on ROCm; report fixed placeholder
// versions so Python-side version checks keep working.
size_t get_cublasLt_version() { return 10000000; }
size_t get_cudnn_version() { return 0; }
#else
size_t get_cublasLt_version() { return cublasLtGetVersion(); }
size_t get_cudnn_version() { return cudnnGetVersion(); }
#endif
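Illustrative only (assumes a ROCm build using the stubs above): the placeholder return values let existing Python-side version checks proceed instead of failing.

import transformer_engine_torch as tex

cudnn_version = tex.get_cudnn_version()  # 0 on ROCm builds, per the stub above
if cudnn_version == 0:
    print("cuDNN is unavailable (ROCm build); skipping cuDNN-specific paths")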
......@@ -174,6 +174,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("input_list"), py::arg("output_list"), py::arg("quantizer_list"), py::arg("otype"));
m.def("te_general_grouped_gemm", &te_general_grouped_gemm, "Grouped GEMM");
#ifdef USE_ROCM
m.def("te_general_batched_gemm", &te_general_batched_gemm, "Batched GEMM"); /// rocblas
#endif
m.def("fused_attn_fwd", &fused_attn_fwd,
"Fused Attention FP8/BF16/FP16 FWD with separate Q, K and V");
m.def("fused_attn_bwd", &fused_attn_bwd,
......@@ -207,6 +210,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version",
py::call_guard<py::gil_scoped_release>());
m.attr("_num_cublas_streams") = py::int_(transformer_engine::num_streams);
#ifdef USE_ROCM
m.attr("_num_cublas_batchgemm_streams") = py::int_(transformer_engine::num_batchgemm_streams);
#endif
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
......