Commit a13c52ad authored by wenjh

Fix user args core dump in mt

parent 3a5755b1
/*************************************************************************
 * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
 *
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <type_traits>
#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#include <unistd.h>
#include <chrono>
#include <forward_list>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <string_view>
#include <unordered_map>
#include <vector>
#endif
#ifdef USE_ROCBLAS
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>
#include <hipblaslt/hipblaslt-ext.hpp>
#include <hipcub/hipcub.hpp>
#endif
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>
#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"
namespace {
#ifdef USE_HIPBLASLT
static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return HIP_R_16F;
    case DType::kFloat32:
      return HIP_R_32F;
    case DType::kBFloat16:
      return HIP_R_16BF;
    case DType::kFloat8E4M3:
      return HIP_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return HIP_R_8F_E5M2;
    case DType::kInt8:
      return HIP_R_8I;
    case DType::kInt32:
      return HIP_R_32I;
    default:
      NVTE_ERROR("Invalid type");
  }
}
#endif
#ifdef USE_ROCBLAS
rocblas_datatype get_rocblas_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return rocblas_datatype_f16_r;
    case DType::kFloat32:
      return rocblas_datatype_f32_r;
    case DType::kBFloat16:
      return rocblas_datatype_bf16_r;
    case DType::kFloat8E4M3:
      return rocblas_datatype_f8_r;
    case DType::kFloat8E5M2:
      return rocblas_datatype_bf8_r;
    default:
      NVTE_ERROR("Invalid type");
  }
}
#endif
}  // namespace
namespace transformer_engine {
#ifdef USE_ROCBLAS
namespace detail {
struct Empty {};
__device__ inline fp32 identity(fp32 value, const Empty&) { return value; }
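// Tanh approximation of GELU: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))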
__inline__ __device__ float gelu(float x, const Empty&) {
  float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}
__inline__ __device__ float gelu_forward(float x) {
  float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}
template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    gelu_forward_kernel(const float* in, T* out, float* amax, const float* scale, int m, int n) {
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float x = in[id];
      float y = gelu_forward(x);
      out[id] = (T)((*scale) * y);
      thread_amax = std::fmax(std::fabs(y), thread_amax);
    }
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float x = in[id];
      float y = gelu_forward(x);
      out[id] = (T)(y);
    }
  }
}
template <typename T>
void gelu_forward_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int m,
                                 int n, hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((gelu_forward_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
                     stream, in, out, amax, scale, m, n);
}
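// Backward of the tanh-approximated GELU. With u = kBeta * (x + kKappa * x^3):
//   d/dx [0.5 * x * (1 + tanh(u))] = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh^2(u)) * u'
// computed below as left_derivative + right_derivative.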
__inline__ __device__ float gelu_backward(float x, float dy) {
  constexpr float kBeta = 0.7978845608028654f;
  constexpr float kKappa = 0.044715f;
  float x_sq = x * x;
  float x_cube = x_sq * x;
  float tanh_inner = tanhf((kBeta * (x + kKappa * x_cube)));
  float left = 0.5 * x;
  float right = 1.0f + tanh_inner;
  float left_derivative = 0.5 * right;
  float tanh_derivative = 1 - tanh_inner * tanh_inner;
  float inner_derivative = kBeta * (1.0f + 3.0 * kKappa * x_sq);
  float right_derivative = left * tanh_derivative * inner_derivative;
  return dy * (left_derivative + right_derivative);
}
template <typename T, typename Taux>
__global__ void gelu_backward_kernel(const float* dy, T* out, const Taux* __restrict pre_gelu_out,
                                     int m, int n) {
  for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
    float x = (float)pre_gelu_out[id];
    float dx = (float)gelu_backward(x, dy[id]);
    out[id] = (T)(dx);
  }
}
template <typename T, typename Taux>
void gelu_backward_kernelLauncher(const float* in, T* out, const Taux* pre_gelu_out, int m, int n,
                                  hipStream_t stream) {
  int blocks_per_row = ceil(float(n) / 256);
  dim3 grid(min(m * blocks_per_row, 65536));
  dim3 block(min(n, 256));
  hipLaunchKernelGGL((gelu_backward_kernel<T, Taux>), dim3(grid), dim3(block), 0, stream, in, out,
                     pre_gelu_out, m, n);
}
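// Adds a per-column bias (broadcast along n) to the fp32 GEMM result; for fp8 outputs
// the sum is additionally scaled by *scale and its amax is tracked with a block reduction.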
template <typename T, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    add_bias_kernel(const float* in, T* out, const Tb* __restrict bias, float* amax,
                    const float* scale, int m, int n) {
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      out[id] = (T)((*scale) * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      out[id] = (T)(val);
    }
  }
}
template <typename T, typename Tb>
void add_bias_kernelLauncher(const float* in, T* out, const Tb* __restrict bias, float* amax,
                             const float* scale, int m, int n, hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((add_bias_kernel<T, Tb, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
                     stream, in, out, bias, amax, scale, m, n);
}
template <typename T, typename Taux, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    add_bias_gelu_kernel(const float* in, T* out, Taux* pre_gelu_out, const Tb* __restrict bias,
                         float* amax, const float* scale, int m, int n) {
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    // only need to deal with amax and scale of D, no need to deal with amax and scale of pre_gelu_out
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      // pre_gelu_out guaranteed not to be fp8 type
      pre_gelu_out[id] = (Taux)(val);
      val = gelu_forward(val);
      out[id] = (T)((*scale) * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      pre_gelu_out[id] = (Taux)(val);
      out[id] = (T)(gelu_forward(val));
    }
  }
}
template <typename T, typename Taux, typename Tb>
void add_bias_gelu_kernelLauncher(const float* in, T* out, Taux* pre_gelu_out,
                                  const Tb* __restrict bias, float* amax, const float* scale, int m,
                                  int n, hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((add_bias_gelu_kernel<T, Taux, Tb, THREADS_PER_BLOCK>), dim3(grid),
                     dim3(block), 0, stream, in, out, pre_gelu_out, bias, amax, scale, m, n);
}
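// Plain dtype-converting copies used when no epilogue math is required;
// the *_output variant additionally applies the fp8 scale and amax bookkeeping.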
template <typename Tin, typename T>
__global__ void identity_kernel(const Tin* in, T* out, int n) {
  for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
    Tin val = in[id];
    out[id] = (T)(val);
  }
}
template <typename Tin, typename T>
void identity_kernelLauncher(const Tin* in, T* out, int n, hipStream_t stream) {
  dim3 block, grid;
  block.x = 256;
  grid.x = ceil(n / 256.);
  hipLaunchKernelGGL((identity_kernel<Tin, T>), dim3(grid), dim3(block), 0, stream, in, out, n);
}
template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    identity_output_kernel(const float* in, T* out, float* amax, const float* scale, int n) {
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
      float val = in[id];
      out[id] = (T)((*scale) * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
      float val = in[id];
      out[id] = (T)(val);
    }
  }
}
template <typename T>
void identity_output_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int n,
                                    hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((identity_output_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
                     stream, in, out, amax, scale, n);
}
template <typename Tin, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    bias_gradient_kernel(const Tin* in, float* out, int m, int n) {
  typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
  __shared__ typename BlockReduce::TempStorage block_temp_storage;
  int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
  int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  int col_idx = idx / THREADS_PER_COL;
  int row_idx = idx % THREADS_PER_COL;
  float thread_data = 0.f;  // zero-init so threads past the column tail add nothing
  if (row_idx < m) thread_data = (float)in[row_idx * n + col_idx];
  float local_sum;
  if (row_idx < (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK) {
    local_sum = BlockReduce(block_temp_storage).Sum(thread_data);
  } else {
    local_sum = BlockReduce(block_temp_storage)
                    .Sum(thread_data, m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
  }
  if (threadIdx.x == 0) atomicAdd(&out[col_idx], local_sum);
}
constexpr int kColwiseReduceTileSize = 32;
template <typename T>
__inline__ __device__ T WarpReduceSum(T val, int max = 32) {
  for (int offset = max; offset > 0; offset >>= 1) {
    val += __shfl_down(val, offset);
  }
  return val;
}
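// Column-wise sum via a 32x32 shared-memory tile: each thread first accumulates a
// strided partial sum down its column, the tile is then read back transposed so the
// 32 partials of one column land in consecutive lanes, and a shuffle reduction
// (offsets 16..1) combines them.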
template <typename InputType>
__launch_bounds__(1024) __global__
    void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) {
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float grad_sum = 0.f;
  if (j < N) {
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      grad_sum += static_cast<float>(src[i * N + j]);
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
  __syncthreads();
  float sum = g_shared[threadIdx.x][threadIdx.y];
  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int j = blockIdx.x * blockDim.x + threadIdx.y;
    if (j < N) {
      dst[j] = static_cast<float>(sum);
    }
  }
}
template <typename OutputType>
__launch_bounds__(1024) __global__
    void tensorwise_int8_bias_gradient_kernel(OutputType* dst, const int8_t* src, float* scale,
                                              int M, int N) {
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float grad_sum = 0.f;
  float tensorwise_scale = scale[0];
  if (j < N) {
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      grad_sum += static_cast<float>(src[i * N + j]) * tensorwise_scale;
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
  __syncthreads();
  float sum = g_shared[threadIdx.x][threadIdx.y];
  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int j = blockIdx.x * blockDim.x + threadIdx.y;
    if (j < N) {
      dst[j] = static_cast<OutputType>(sum);
    }
  }
}
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
                                  hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
  block.x = THREADS_PER_BLOCK;
  grid.x = BLOCKS_PER_COL * n;
  if (!stream_order_alloc) {
    NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(float)));
  } else {
    NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(float), stream));
  }
  // hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
  int B = (n - 1) / kColwiseReduceTileSize + 1;
  bias_gradient_kernel_v2<Tin>
      <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
}
template <typename Tout>
void tensorwise_int8_bias_gradient_kernelLauncher(const int8_t* in, Tout* out, float* scale, int m,
                                                  int n, hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
  block.x = THREADS_PER_BLOCK;
  grid.x = BLOCKS_PER_COL * n;
  NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(Tout), stream));
  int B = (n - 1) / kColwiseReduceTileSize + 1;
  tensorwise_int8_bias_gradient_kernel<Tout>
      <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, scale, m,
                                                                               n);
}
}  // namespace detail
transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
  using namespace transformer_engine;
  switch (t) {
    case rocblas_datatype_f16_r:
      return DType::kFloat16;
    case rocblas_datatype_f32_r:
      return DType::kFloat32;
    case rocblas_datatype_bf16_r:
      return DType::kBFloat16;
    case rocblas_datatype_f8_r:
      return DType::kFloat8E4M3;
    case rocblas_datatype_bf8_r:
      return DType::kFloat8E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}
#endif  // USE_ROCBLAS
#ifdef USE_HIPBLASLT
namespace {
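// Per-process pool of hipblasLt handles, one free-list per device, guarded by a mutex.
// Each thread borrows a handle through the thread_local HandleCache below and returns
// it to the pool when the thread exits.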
static class HandlePool {
 public:
  hipblasLtHandle_t get(int device_id) {
    std::lock_guard<std::mutex> lock(mt);
    if (pool.empty()) {
      int device_count = 0;
      NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
      pool.resize(device_count);
      return nullptr;
    }
    if (!pool[device_id].empty()) {
      hipblasLtHandle_t h = pool[device_id].front();
      pool[device_id].pop_front();
      return h;
    }
    return nullptr;
  }
  hipblasLtHandle_t obtain(int device_id) {
    hipblasLtHandle_t h = get(device_id);
    if (h == nullptr) {
      NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
    }
    return h;
  }
  void store(const std::vector<hipblasLtHandle_t>& handles) {
    std::lock_guard<std::mutex> lock(mt);
    if (pool.empty()) {
      std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
    }
    for (unsigned int i = 0; i < pool.size(); i++) {
      if (handles[i] != nullptr) {
        pool[i].push_front(handles[i]);
      }
    }
  }
  ~HandlePool() {
#if DESTROY_HIPBLASLT_HANDLES_POOL
    std::lock_guard<std::mutex> lock(mt);
    for (auto& hlist : pool) {
      for (auto& h : hlist) {
        hipblasLtDestroy(h);
      }
    }
    pool.clear();
#endif
  }
  inline size_t get_size() const { return pool.size(); }
 private:
  std::mutex mt;
  using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
  // The destruction order between thread_local and global objects is not guaranteed,
  // so as a simple workaround make the pool storage "leaky":
  // do not destruct it and do not destroy the hipblasLt handles;
  // let the OS reclaim them on application exit.
#if DESTROY_HIPBLASLT_HANDLES_POOL
  Pool pool;
#else
  Pool& pool = *new Pool();
#endif
} handle_pool;
thread_local static class HandleCache {
 public:
  hipblasLtHandle_t get(int device_id) const { return d.empty() ? nullptr : d[device_id]; }
  hipblasLtHandle_t obtain(int device_id) {
    hipblasLtHandle_t h = get(device_id);
    if (h) {
      return h;
    }
    h = handle_pool.obtain(device_id);
    set(device_id, h);
    return h;
  }
  void set(int device_id, hipblasLtHandle_t h) {
    if (d.empty()) {
      d.resize(handle_pool.get_size());
    }
    d[device_id] = h;
  }
  ~HandleCache() {
    if (!d.empty()) {
      handle_pool.store(d);
    }
  }
 private:
  std::vector<hipblasLtHandle_t> d;
} cached_handles;
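// Minimal CSV output helper: a separator is emitted before every field except the first;
// streaming csv_helper::end() closes a row and csv_helper::start() re-arms separator
// emission for the next one.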
class csv_helper {
 public:
  struct start {};
  struct end {};
  csv_helper(std::ostream& os, char sep_val)
      : m_os{os}, m_sep_val(sep_val), m_start(true), m_sep("") {}
  csv_helper& operator<<(const start&) {
    m_start = true;
    return *this;
  }
  csv_helper& operator<<(const end&) {
    m_sep = "";
    m_start = false;
    return *this;
  }
  template <typename T>
  csv_helper& operator<<(const T& v) {
    m_os << m_sep << v;
    if (m_start) {
      m_start = false;
      m_sep = m_sep_val;
    }
    return *this;
  }
 private:
  std::ostream& m_os;
  char m_sep_val;
  bool m_start;
  std::string m_sep;
};
template <typename T>
class NameMapper {
 public:
  NameMapper(const std::unordered_map<T, std::string_view>& name_map) : map(name_map) {}
  const std::string_view& getName(const T& val) { return map.at(val); }
  T getValue(const std::string& name, const char* label = "",
             std::function<bool(const T&)> filter = nullptr) {
    for (auto iter = map.begin(); iter != map.end(); ++iter) {
      if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
    }
    NVTE_ERROR("Invalid ", label, " name: ", name);
  }
 protected:
  const std::unordered_map<T, std::string_view>& map;
};
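// Enum <-> token tables used to (de)serialize the algo-cache CSV entries.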
static std::unordered_map<hipDataType, std::string_view> type_name_map = {
    {HIP_R_32F, "float32"},
    {HIP_R_16F, "float16"},
    {HIP_R_16BF, "bfloat16"},
    {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
    {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
#if HIP_VERSION >= 60300000
    {HIP_R_8F_E4M3, "float8e4m3"},
    {HIP_R_8F_E5M2, "float8e5m2"},
#endif
};
static NameMapper<hipDataType> typeNameMapper(type_name_map);
static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
    {HIPBLAS_OP_N, "N"}, {HIPBLAS_OP_T, "T"}};
static NameMapper<hipblasOperation_t> transposeNameMapper(trans_name_map);
static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map = {
    {HIPBLASLT_EPILOGUE_DEFAULT, "-"},        {HIPBLASLT_EPILOGUE_BIAS, "bias"},
    {HIPBLASLT_EPILOGUE_GELU_AUX, "geluaux"}, {HIPBLASLT_EPILOGUE_GELU_AUX_BIAS, "geluauxbias"},
    {HIPBLASLT_EPILOGUE_DGELU, "dgelu"},      {HIPBLASLT_EPILOGUE_DGELU_BGRAD, "dgelubgrad"},
    {HIPBLASLT_EPILOGUE_BGRADB, "bgradb"}};
static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);
static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
    {HIPBLAS_COMPUTE_32F, "f32"}};
static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
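// Process-wide cache of tuned hipblasLt algorithms, keyed by the full GEMM configuration
// plus a workspace-size range; optionally loaded from / saved to a CSV file (see load_/save_).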
static class GemmAlgoCache {
 public:
  struct Key {
    int deviceCap;
    hipDataType a_type, b_type, d_type, bias_type;
    int m, n, k;
    int lda, ldb, ldd;
    hipblasOperation_t transa, transb;
    hipblasLtEpilogue_t epilogue;
    Key(int deviceCap_, hipDataType a_type_, hipDataType b_type_, hipDataType d_type_,
        hipDataType bias_type_, int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
        hipblasOperation_t transa_, hipblasOperation_t transb_, hipblasLtEpilogue_t epilogue_)
        : deviceCap(deviceCap_),
          a_type(a_type_),
          b_type(b_type_),
          d_type(d_type_),
          bias_type(bias_type_),
          m(m_),
          n(n_),
          k(k_),
          lda(lda_),
          ldb(ldb_),
          ldd(ldd_),
          transa(transa_),
          transb(transb_),
          epilogue(epilogue_) {}
    Key() {}
    bool operator==(const Key& val) const {
      return ((deviceCap == val.deviceCap) && (a_type == val.a_type) && (b_type == val.b_type) &&
              (d_type == val.d_type) && (bias_type == val.bias_type) && (m == val.m) &&
              (n == val.n) && (k == val.k) && (lda == val.lda) && (ldb == val.ldb) &&
              (ldd == val.ldd) && (transa == val.transa) && (transb == val.transb) &&
              (epilogue == val.epilogue));
    }
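    // Byte-wise strict weak ordering over the raw Key bytes; any consistent ordering
    // suffices to organize the multimap below.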
    struct Comp {
      bool operator()(const Key& lhs, const Key& rhs) const {
        return ::std::string_view((const char*)&lhs, sizeof(lhs)) <
               ::std::string_view((const char*)&rhs, sizeof(rhs));
      }
    };
  };
  void init() {
    std::lock_guard<std::mutex> lock(mt);
    int device_count = 0;
    NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
    dev_cap.resize(device_count);
    for (int i = 0; i < device_count; i++) {
      hipDeviceProp_t prop;
      NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, i));
      dev_cap[i] = prop.major * 100 + prop.minor;
    }
    load_();
    save_();
  }
  inline int device_cap(int device_id) {
    if (dev_cap.empty()) init();
    return dev_cap[device_id];
  }
  struct Algo {
    std::optional<hipblasLtMatmulAlgo_t> algo;
    int64_t algoId;
    int index;
    size_t ws_size_min;
    size_t ws_size_max;
    Algo() : algo(), algoId(), index(-1), ws_size_min(0), ws_size_max(0) {}
    Algo(int idx, int64_t id, size_t ws_min, size_t ws_max)
        : algo(), algoId(id), index(idx), ws_size_min(ws_min), ws_size_max(ws_max) {}
    inline bool hasId() { return index >= 0; }
    const static inline int64_t getAlgoId(const hipblasLtMatmulAlgo_t& algo) {
      return *(const int64_t*)&algo;
    }
  };
  bool find(const Key& cfg, size_t ws_size, Algo& algo) {
    std::lock_guard<std::mutex> lock(mt);
    if (auto* pentry = find_(cfg, ws_size, ws_size); pentry != nullptr) {
      algo = *pentry;
      return true;
    }
    return false;
  }
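  // store() keeps the workspace-size ranges for a key non-overlapping: entries fully
  // covering the new range are replaced in place, partially overlapping ranges are
  // trimmed, and an adjacent entry with the same algoId is extended instead of
  // inserting a duplicate.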
  void store(const Key& cfg, const Algo& algo) {
    size_t ws_size_min = algo.ws_size_min;
    size_t ws_size_max = algo.ws_size_max;
    NVTE_CHECK(ws_size_max >= ws_size_min, "Invalid WS size");
    std::lock_guard<std::mutex> lock(mt);
    // Remove overlap with existing entries
    while (auto* pentry = find_(cfg, ws_size_min, ws_size_max)) {
      if (pentry->ws_size_min <= ws_size_min && pentry->ws_size_max >= ws_size_max) {
        *pentry = algo;
        save_();
        return;
      }
      if (ws_size_max > pentry->ws_size_max) {
        ws_size_min = pentry->ws_size_max + 1;
      } else if (ws_size_min < pentry->ws_size_min) {
        ws_size_max = pentry->ws_size_min - 1;
      } else {
        // Should never be here
        NVTE_ERROR("Cannot merge WS size range");
      }
    }
    // Merge into an adjacent entry if possible
    auto* pentry = find_(cfg, ws_size_min - 1, ws_size_min);
    if (pentry && pentry->algoId == algo.algoId) {
      pentry->algo = algo.algo;
      pentry->ws_size_max = ws_size_max;
      save_();
    } else {
      auto it = d.emplace(cfg, algo);
      it->second.ws_size_min = ws_size_min;
      it->second.ws_size_max = ws_size_max;
      save_(it->first, it->second);
    }
  }
 protected:
  Algo* find_(const Key& cfg, size_t ws_min, size_t ws_max) {
    const auto key_range = d.equal_range(cfg);
    for (auto i = key_range.first; i != key_range.second; i++) {
      if (ws_min <= i->second.ws_size_max && ws_max >= i->second.ws_size_min) {
        return &i->second;
      }
    }
    return nullptr;
  }
  void header_(std::ostream& ofs) {
    csv_helper fs(ofs, csv_sep);
    fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
       << "type_a" << "type_b" << "type_d" << "bias_type"
       << "lda" << "ldb" << "ldd" << "epi" << "comp" << "scale"
       << "ws_min" << "ws_max" << "algo_id" << "aidx";
  }
  void load_() {
    const char* env = std::getenv("TE_HIPBLASLT_ALGO_LOAD");
    if (env == nullptr || env[0] == '\0') {
      return;
    }
    std::ifstream ifs{env};
    if (!ifs.is_open()) {
      std::cerr << "Could not load autotune results storage " << env << "\n";
      return;
    }
    std::cout << "Loading autotune results from " << env << "\n";
    Key cfg;
    std::string line;
    std::getline(ifs, line);  // the first line with legend
    {
      std::ostringstream hline;
      header_(hline);
      if (hline.str() != line) {
        std::cerr << "Incorrect algo storage legend. Expected " << hline.str() << "\n";
        return;
      }
    }
    while (std::getline(ifs, line)) {
      line.erase(0, line.find_first_not_of(" \t\n\r\f\v"));
      if (auto pos = line.find_last_not_of(" \t\n\r\f\v"); pos != std::string::npos) {
        line.resize(pos + 1);
      }
      if (line.empty() || line[0] == '#') continue;
      std::istringstream is(line);
      char c;
      std::string type_a, type_b, type_d, bias_type, trans_a, trans_b, epi, comp, scale;
      int64_t algo_id;
      int algo_idx;
      size_t ws_min, ws_max;
      is >> std::skipws;
      is >> cfg.deviceCap >> c >> cfg.m >> c >> cfg.n >> c >> cfg.k >> c;
      // Filter out entries for devices not present on the current system
      bool b_found = false;
      for (size_t i = 0; i < dev_cap.size(); i++) {
        if (dev_cap[i] == cfg.deviceCap) {
          b_found = true;
          break;
        }
      }
      if (!b_found) continue;
      std::getline(is, trans_a, csv_sep);
      std::getline(is, trans_b, csv_sep);
      std::getline(is, type_a, csv_sep);
      std::getline(is, type_b, csv_sep);
      std::getline(is, type_d, csv_sep);
      std::getline(is, bias_type, csv_sep);
      is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c;
      std::getline(is, epi, csv_sep);
      std::getline(is, comp, csv_sep);
      std::getline(is, scale, csv_sep);
      is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx;
      if (is.bad()) {
        std::cerr << "Parsing CSV line failed: " << line << "\n";
        return;
      }
      if (ws_min > ws_max) {
        std::cout << "[WARNING] Invalid WS size at " << line << "\n";
        continue;
      }
#if HIP_VERSION >= 60300000
      auto fp8_filter = [](const hipDataType& val) {
        return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
      };
#else
      auto fp8_filter = nullptr;
#endif
      cfg.a_type = typeNameMapper.getValue(type_a, "type_a", fp8_filter);
      cfg.b_type = typeNameMapper.getValue(type_b, "type_b", fp8_filter);
      cfg.d_type = typeNameMapper.getValue(type_d, "type_d", fp8_filter);
      cfg.bias_type = (bias_type == "-")
                          ? (hipDataType)-1
                          : typeNameMapper.getValue(bias_type, "bias_type", fp8_filter);
      cfg.transa = transposeNameMapper.getValue(trans_a, "trans_a");
      cfg.transb = transposeNameMapper.getValue(trans_b, "trans_b");
      cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");
      // Check and filter out compute and scale types
      if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
          typeNameMapper.getValue(scale, "scale") != HIP_R_32F) {
        continue;
      }
      if (find_(cfg, ws_min, ws_max)) {
        std::cout << "[WARNING] Duplicated/overlapped entry in algo cache\n";
        continue;
      }
      d.emplace(cfg, Algo(algo_idx, algo_id, ws_min, ws_max));
    }
  }
  bool can_save_(bool reopen = false) {
    if (!save_fs) {
      const char* temp = std::getenv("TE_HIPBLASLT_ALGO_SAVE");
      if (temp == nullptr || temp[0] == '\0') {
        return false;
      }
      save_fs_name = temp;
      pid_t pid = getpid();
      size_t pos = 0;
      while ((pos = save_fs_name.find("%i", pos)) != std::string::npos) {
        save_fs_name.replace(pos, 2, std::to_string(pid));
      }
      save_fs = std::make_unique<std::ofstream>();
      std::cout << "Saving autotune results to " << save_fs_name << "\n";
    }
    if (reopen) {
      if (save_fs->is_open()) {
        save_fs->close();
      }
      save_fs->open(save_fs_name, std::ios_base::trunc);
    }
    if (save_fs->is_open() && !save_fs->bad()) {
      return true;
    } else {
      if (reopen) std::cerr << "Could not open autotune results storage " << save_fs_name << "\n";
      return false;
    }
  }
  void save_() {
    if (!can_save_(true)) {
      return;
    }
    header_(*save_fs);
    *save_fs << "\n";
    for (const auto& elem : d) {
      save_(elem.first, elem.second);
    }
  }
  void save_(const Key& cfg, const Algo& algo) {
    if (!can_save_()) {
      return;
    }
    csv_helper csv(*save_fs, csv_sep);
    csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k << transposeNameMapper.getName(cfg.transa)
        << transposeNameMapper.getName(cfg.transb) << typeNameMapper.getName(cfg.a_type)
        << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
        << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
        << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
        << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
        << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end()
        << "\n";
  }
 private:
  std::vector<int> dev_cap;
  constexpr static char csv_sep = ',';
  std::unique_ptr<std::ofstream> save_fs;
  std::string save_fs_name;
  std::mutex mt;
  /* Map from problem config to Algo (with its workspace size range).
   * When searching, elements matching the Key are filtered so that the
   * requested workspace size falls between Algo.ws_size_min and ws_size_max.
   */
  std::multimap<Key, Algo, Key::Comp> d;
} algoCache;
static inline int getIntEnv(const char* name, int defval, int minval) {
  int val = defval;
  const char* env = std::getenv(name);
  if (env != nullptr && env[0] != '\0') {
    val = atoi(env);
    if (val < minval) {
      val = minval;
    }
  }
  return val;
}
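/* Tuning-related environment variables read below via getIntEnv (values are
 * clamped to the given minimum):
 *   TE_HIPBLASLT_ALGO_SELECTION    - index of the first heuristic algo to consider (default 0)
 *   TE_HIPBLASLT_TUNING_RUN_COUNT  - profiling loops per algo; 0 disables tuning (default 0)
 *   TE_HIPBLASLT_TUNING_ALGO_COUNT - number of algos to profile (default 16)
 *   TE_HIPBLASLT_LOG_TUNING        - non-zero enables tuning log messages (default 0)
 * Illustrative usage: TE_HIPBLASLT_TUNING_RUN_COUNT=10 TE_HIPBLASLT_LOG_TUNING=1 python train.py
 */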
}  //namespace

/* Warning: only call once per device!
 * When calling nvte_multi_stream_cublas_gemm with the hipblaslt backend we
 * need to create one handle per compute stream, so that a single handle is
 * never used by multiple streams concurrently.
 */
static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
  NVTE_CHECK(hipblaslt_handles != nullptr);
  for (int i = 0; i < compute_num_streams; i++) {
    NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&hipblaslt_handles[i]));
  }
}
transformer_engine::DType get_transformer_engine_dtype_from_hipblaslt_dtype(const hipDataType t) {
  using namespace transformer_engine;
  switch (t) {
    case HIP_R_16F:
      return DType::kFloat16;
    case HIP_R_32F:
      return DType::kFloat32;
    case HIP_R_16BF:
      return DType::kBFloat16;
    default:
      NVTE_ERROR("Invalid type");
  }
}
void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                    const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                    int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
                    bool grad, void* workspace, size_t workspaceSize, bool accumulate,
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, const Tensor* inputCounter, hipStream_t stream,
                    hipblasLtHandle_t handle) {
  void* A = inputA->data.dptr;
  void* A_scale_inverse = inputA->scale_inv.dptr;
  float* A_scale_inverse_float = (float*)(inputA->scale_inv.dptr);
  void* B = inputB->data.dptr;
  void* B_scale_inverse = inputB->scale_inv.dptr;
  float* B_scale_inverse_float = (float*)(inputB->scale_inv.dptr);
  void* D = outputD->data.dptr;
  void* bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void* pre_gelu_out = outputPreGelu->data.dptr;
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const bool use_int8 = is_int8_dtype(inputA->data.dtype) || is_int8_dtype(inputB->data.dtype);
  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
  const hipDataType bias_type = get_hipblaslt_dtype(inputBias->data.dtype);
  NVTE_CHECK(!is_fp8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_fp8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "FP8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_int8_dtype(inputA->data.dtype) || A_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");
  NVTE_CHECK(!is_int8_dtype(inputB->data.dtype) || B_scale_inverse != nullptr,
             "INT8 input to GEMM requires inverse of scale!");
  bool tensorwise_int8 = false;
  const char* NVTE_INT8_SIM_FP8_TENSORWISE = std::getenv("NVTE_INT8_SIM_FP8_TENSORWISE");
  if (NVTE_INT8_SIM_FP8_TENSORWISE != nullptr && NVTE_INT8_SIM_FP8_TENSORWISE[0] == '1' &&
      use_int8) {
    tensorwise_int8 = true;
  }
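  // NVTE_INT8_SIM_FP8_TENSORWISE=1 makes INT8 inputs take the tensor-wise
  // scaling path below: per-tensor scale-inverse pointers are attached to the
  // matmul descriptor, and the bias gradient (if any) is computed by a
  // separate kernel instead of a hipBLASLt epilogue.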
  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 || use_int8) {
    NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!");
  }

  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  int device_id;
  NVTE_CHECK_CUDA(hipGetDevice(&device_id));
  if (handle == nullptr) {
    handle = cached_handles.get(device_id);
    if (handle == nullptr) {
      handle = cached_handles.obtain(device_id);
    }
  }

  hipblasLtMatmulDesc_t operationDesc = nullptr;
  hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  hipblasLtMatmulPreference_t preference = nullptr;
  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
  int64_t ld_gelumat = (int64_t)ldd;

  // default to tf32 except for e5m2 inputs where the config is not supported
  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type, transa == HIPBLAS_OP_N ? m : k,
                                                   transa == HIPBLAS_OP_N ? k : m, lda));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type, transb == HIPBLAS_OP_N ? k : n,
                                                   transb == HIPBLAS_OP_N ? n : k, ldb));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                                       &transa, sizeof(transa)));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
                                                       &transb, sizeof(transb)));
  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
  // amax(D) either (next op is high precision).
  if (use_fp8) {
    // Split accumulator.
    const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1;
    /*
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                        HIPBLASLT_MATMUL_DESC_FAST_ACCUM,  //TODO: We don't have fast accum mode yet
                                        &fastAccuMode,
                                        sizeof(fastAccuMode)));
    */
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                        &A_scale_inverse, sizeof(A_scale_inverse)));
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                        &B_scale_inverse, sizeof(B_scale_inverse)));
    if (bias) {
      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  }
  if (tensorwise_int8) {
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                         HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                                         (void*)&A_scale_inverse_float,
                                                         sizeof(void*)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                         HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                         (void*)&B_scale_inverse_float,
                                                         sizeof(void*)));
  }
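  // Epilogue selection (mirrors the branches below):
  //   bias && gelu : grad ? DGELU_BGRAD : GELU_AUX_BIAS
  //   bias only    : grad ? BGRADB      : BIAS   (tensorwise int8 grad uses a kernel instead)
  //   gelu only    : grad ? DGELU       : GELU_AUX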
  if (bias && gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU_BGRAD;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
        operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                         HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                         &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
  } else if (bias) {
    if (tensorwise_int8) {
      if (grad) {
        int batch_size = k;
        int output_dim = n;
        DType te_bias_dtype = get_transformer_engine_dtype_from_hipblaslt_dtype(bias_type);
        TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
            te_bias_dtype, BType,
            detail::tensorwise_int8_bias_gradient_kernelLauncher<BType>(
                reinterpret_cast<const int8_t*>(B), reinterpret_cast<BType*>(bias_ptr),
                B_scale_inverse_float, batch_size, output_dim, stream););
      } else {
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
            operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
        epilogue = HIPBLASLT_EPILOGUE_BIAS;
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
            operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
      }
    } else {
      if (grad) {
        // grad output is always input B
        epilogue = HIPBLASLT_EPILOGUE_BGRADB;
      } else {
        epilogue = HIPBLASLT_EPILOGUE_BIAS;
      }
      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    }
  } else if (gelu) {
    if (grad) {
      epilogue = HIPBLASLT_EPILOGUE_DGELU;
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                         HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                         &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
  }
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
      operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));

  GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
                              use_fp8 ? bias_type : (hipDataType)-1, m, n, k, lda, ldb, ldd, transa,
                              transb, epilogue);
  GemmAlgoCache::Algo cached_algo;
  if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) {
    int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
    int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
    int algoTuneCount = 1;
    std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
    bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
    if (tuneLoopCount) {
      /* HIPBLASLT may return hundreds of algos for some configs.
       * Limit the amount by default. The user may override it with an env variable.
       */
      static const int defaultAlgoCount = 16;
      algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
    }
    algoTuneCount += firstAlgo;
    int algoTotalCount =
        cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
    algoArr.resize(algoTotalCount);
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulPreferenceSetAttribute(preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                              &workspaceSize, sizeof(workspaceSize)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
                                                         Ddesc, preference, algoTotalCount,
                                                         algoArr.data(), &algoTotalCount));
    algoArr.resize(algoTotalCount);
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));

    //If a cached algo exists in persistent storage we just need to find the matching hipblasLtMatmulAlgo_t
    if (cached_algo.hasId()) {
      int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
      for (int i = 0; i < algoTotalCount; i++) {
        const auto& algo = algoArr[idx];
        if (algo.state == HIPBLAS_STATUS_SUCCESS) {
          if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo)) {
            cached_algo.algo = algo.algo;
            if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index) {
              cached_algo.ws_size_min = algo.workspaceSize;
              cached_algo.index = idx;
              algoCache.store(gemm_cfg, cached_algo);
            }
            break;
          }
        }
        idx = (idx + 1) % algoTotalCount;
      }
      if (logTuning && !cached_algo.algo.has_value()) {
        std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId
                  << " in hipBLASLt results" << std::endl;
      }
    }
    //No suitable entry in the autotune cache, or no matching algo in the hipBLASLt results
    if (!cached_algo.algo.has_value()) {
      int bestAlgo = -1;
      algoTuneCount = std::min(algoTuneCount, algoTotalCount);
      if (tuneLoopCount > 0) {
        if (logTuning)
          std::cout << "[INFO] Perform hipBLASLt algo selection on GPU " << device_id
                    << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
                    << tuneLoopCount << " loops" << std::endl;
        NVTE_CHECK_CUDA(hipStreamSynchronize(stream));
        hipStream_t profilingStream;
        NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
        using tuning_clock = std::chrono::steady_clock;
        tuning_clock::now();  //the first call takes a little longer, so do it outside the loop
        tuning_clock::duration bestTime = tuning_clock::duration::max();
        for (int algo = firstAlgo; algo < algoTuneCount; algo++) {
          if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS) {
            continue;
          }
          // Warm-up call
          NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                               static_cast<const void*>(&one), /* alpha */
                                               A,                              /* A */
                                               Adesc, B,                       /* B */
                                               Bdesc, static_cast<const void*>(&beta), /* beta */
                                               D,                              /* C */
                                               Ddesc, D,                       /* D */
                                               Ddesc, &algoArr[algo].algo,     /* algo */
                                               workspace,                      /* workspace */
                                               workspaceSize, profilingStream)); /* stream */
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
          //Profiling loop
          tuning_clock::time_point startTime = tuning_clock::now();
          for (int loop = 0; loop < tuneLoopCount; loop++) {
            NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                                 static_cast<const void*>(&one), /* alpha */
                                                 A,                              /* A */
                                                 Adesc, B,                       /* B */
                                                 Bdesc, static_cast<const void*>(&beta), /* beta */
                                                 D,                              /* C */
                                                 Ddesc, D,                       /* D */
                                                 Ddesc, &algoArr[algo].algo,     /* algo */
                                                 workspace,                      /* workspace */
                                                 workspaceSize, profilingStream)); /* stream */
          }
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
          tuning_clock::duration algoTime = tuning_clock::now() - startTime;
          if (algoTime < bestTime) {
            bestAlgo = algo;
            bestTime = algoTime;
          }
        }
        NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
        if (bestAlgo >= 0) {
          if (logTuning)
            std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
                      << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() /
                             tuneLoopCount
                      << " ns" << std::endl;
        }
      } else if (firstAlgo < algoTuneCount) {
        bestAlgo = firstAlgo;
      }
      if (bestAlgo < 0) {
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
        NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
        throw std::runtime_error("Unable to find any suitable algorithms");
      }
      cached_algo.algo = algoArr[bestAlgo].algo;
      cached_algo.index = bestAlgo;
      cached_algo.algoId = cached_algo.getAlgoId(algoArr[bestAlgo].algo);
      cached_algo.ws_size_min = algoArr[bestAlgo].workspaceSize;
      cached_algo.ws_size_max = workspaceSize;
      if (logTuning)
        std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId
                  << std::endl;
      algoCache.store(gemm_cfg, cached_algo);
    }
  }
  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                       static_cast<const void*>(&one), /* alpha */
                                       A,                              /* A */
                                       Adesc, B,                       /* B */
                                       Bdesc, static_cast<const void*>(&beta), /* beta */
                                       D,                              /* C */
                                       Ddesc, D,                       /* D */
                                       Ddesc, &cached_algo.algo.value(), /* algo */
                                       workspace,                      /* workspace */
                                       workspaceSize, stream));        /* stream */

  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Adesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescDestroy(operationDesc));
}
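// This commit replaces the previous global per-device UserArguments managers
// with the thread_local, per-size caches below. With the global managers, two
// threads running grouped GEMMs on the same device shared one UserArguments
// buffer and could overwrite each other mid-launch, which is presumably the
// "user args core dump in mt" this change fixes.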
struct HipBlasLtUserArgsDeleter {
  void operator()(hipblaslt_ext::UserArguments* ptr) const noexcept {
    hipFree(ptr);
  }
};

using HipBlasLtUserArgsPtr = std::unique_ptr<hipblaslt_ext::UserArguments, HipBlasLtUserArgsDeleter>;

inline HipBlasLtUserArgsPtr make_hipblaslt_user_args_ptr(size_t size, bool host) {
  hipblaslt_ext::UserArguments* raw_ptr = nullptr;
  if (host) {
    NVTE_CHECK_CUDA(hipHostMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
  } else {
    NVTE_CHECK_CUDA(hipMalloc(&raw_ptr, size * sizeof(hipblaslt_ext::UserArguments)));
  }
  return HipBlasLtUserArgsPtr(raw_ptr);
}

inline hipblaslt_ext::UserArguments* get_hipblaslt_user_args(size_t size, bool host) {
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> host_userargs_cache;
  thread_local static std::unordered_map<size_t, HipBlasLtUserArgsPtr> device_userargs_cache;
  std::unordered_map<size_t, HipBlasLtUserArgsPtr>& user_args_cache =
      host ? host_userargs_cache : device_userargs_cache;
  auto size_it = user_args_cache.find(size);
  if (size_it != user_args_cache.end()) {
    return size_it->second.get();
  } else {
    HipBlasLtUserArgsPtr user_args = make_hipblaslt_user_args_ptr(size, host);
    hipblaslt_ext::UserArguments* raw_ptr = user_args.get();
    user_args_cache[size] = std::move(user_args);
    return raw_ptr;
  }
}
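// Note: each (thread, size, host/device) combination keeps one buffer alive for
// the lifetime of the thread; the unique_ptr deleter releases it at thread exit.
// hipblaslt_groupedgemm below requests one host and one device buffer of
// m.size() elements per call, so repeated calls with the same group count reuse
// the same allocations.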
void hipblaslt_groupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                           std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
                           std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
                           hipblasOperation_t transa, hipblasOperation_t transb, void* workspace,
                           size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                           int math_sm_count, hipStream_t stream, int compute_stream_offset = 0) {
  // Check that compute_stream_offset is valid.
  NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);

  hipblaslt_ext::UserArguments* userArgs = get_hipblaslt_user_args(m.size(), true);
  hipblaslt_ext::UserArguments* d_userArgs = get_hipblaslt_user_args(m.size(), false);

  hipblasLtHandle_t handle = nullptr;
  if (compute_stream_offset != -1) {
    // Init hipblaslt handles (once, globally)
    static std::once_flag init_flag;
    static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
    std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
    handle = hipblaslt_handles[compute_stream_offset];
  }

  const hipDataType A_type = get_hipblaslt_dtype(inputA[0]->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB[0]->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD[0]->data.dtype);

  hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F;
  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;
  int int_one = 1;
  int int_zero = 0;
  int int_beta = int_zero;
  bool use_int8 = false;

  if ((A_type == HIP_R_8I) && (B_type == HIP_R_8I) && (D_type == HIP_R_32I)) {
    NVTE_CHECK(!accumulate, "Int8 GEMM does not support accumulate.");
    use_int8 = true;
    computeType = HIPBLAS_COMPUTE_32I;
  }

  hipblaslt_ext::GemmPreference gemmPref;
  gemmPref.setMaxWorkspaceBytes(workspaceSize);
  hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type,
                                         computeType);
  std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
      hipblaslt_ext::
          GemmEpilogue()};  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
  std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
  for (size_t i = 0; i < m.size(); i++) {
    inputs[i].a = inputA[i]->data.dptr;
    inputs[i].b = inputB[i]->data.dptr;
    inputs[i].c = outputD[i]->data.dptr;
    inputs[i].d = outputD[i]->data.dptr;
    inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
    inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
  }
  // hipblaslt_ext::GemmEpilogue supports broadcasting
  groupedgemm.setProblem(m, n, k, b, epilogue, inputs);

  const int request_solutions = 1;
  std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
  NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));

  if (heuristicResult.empty()) {
    std::cerr << "No valid solution found!" << std::endl;
    return;
  }
  // Make sure to initialize every time the algo changes
  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));

  // Get the default values from the groupedgemm object
  groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
  // Copy them to device memory
  NVTE_CHECK_CUDA(hipMemcpy(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
                            hipMemcpyHostToDevice));

  NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
}
#endif  //USE_HIPBLASLT

#ifdef USE_ROCBLAS  // Use rocblas + kernel, no fusion

inline void CreateRocblasHandle(rocblas_handle* handle) {
  NVTE_CHECK_ROCBLAS(rocblas_create_handle(handle));
}

using rocblasHandleManager = detail::HandleManager<rocblas_handle, CreateRocblasHandle>;
void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                  const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                  int ldb, int ldd, rocblas_operation transa, rocblas_operation transb, bool grad,
                  void* workspace, size_t workspaceSize, bool accumulate,
                  bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                  bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
  void* A = inputA->data.dptr;
  void* A_scale_inverse = inputA->scale_inv.dptr;
  void* B = inputB->data.dptr;
  void* B_scale_inverse = inputB->scale_inv.dptr;
  void* C = outputD->data.dptr;
  void* D = outputD->data.dptr;
  void* D_scale = outputD->scale.dptr;
  void* D_amax = outputD->amax.dptr;
  void* bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void* pre_gelu_out = outputPreGelu->data.dptr;
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const rocblas_datatype A_type = get_rocblas_dtype(inputA->data.dtype);
  const rocblas_datatype B_type = get_rocblas_dtype(inputB->data.dtype);
  const rocblas_datatype D_type = get_rocblas_dtype(outputD->data.dtype);
  const rocblas_datatype bias_type = get_rocblas_dtype(inputBias->data.dtype);
  const rocblas_datatype gelu_type = get_rocblas_dtype(outputPreGelu->data.dtype);

  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
  }
  // fp8 + grad unavailable in upstream
  NVTE_CHECK(!(use_fp8 && grad), "fp8 + grad not supported!");

  float one = 1.0;
  float zero = 0.0;
  float beta = (accumulate) ? one : zero;

  float alpha = 1.0;
  if (use_fp8) {
    float A_scale_inv, B_scale_inv;
    (void)hipMemcpy(&A_scale_inv, A_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
    (void)hipMemcpy(&B_scale_inv, B_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
    alpha = A_scale_inv * B_scale_inv;
  }
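  // Unlike the hipBLASLt path, no scale-pointer attributes are set here: the
  // FP8 inverse scales are folded into alpha on the host (rocblas_gemm_ex3
  // takes plain alpha/beta). The two hipMemcpy calls above are blocking
  // device-to-host copies, so the FP8 path synchronizes with the device.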
  rocblas_handle handle = rocblasHandleManager::Instance().GetHandle();
  NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream));

  // extract the stream order alloc env
  bool stream_order_alloc = false;
  if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC")) {
    if (std::string(env_p) == "1") stream_order_alloc = true;
  }

  int64_t ld_gelumat = (int64_t)ldd;

  NVTE_CHECK((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
              D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f32_r && B_type == rocblas_datatype_f32_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f8_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf8_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_f8_r) ||
                 (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
                  D_type == rocblas_datatype_bf8_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f32_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f16_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf16_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_f8_r) ||
                 (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
                  D_type == rocblas_datatype_bf8_r),
             "Only the following combinations of data types are enabled now!\n\
1. input: fp32, output: fp32.\n\
2. input: fp16, output: fp16/fp32.\n\
3. input: bf16, output: bf16/fp32.\n\
4. input: fp8/bf8, output: fp8/bf8, fp16/bf16, fp32");
  //If D is not fp32, then we need a temp buffer for the GEMM result before applying epilogues. Otherwise, we can apply epilogues in-place.
  // with bias or gelu, allocate fp32 D_temp if the output is not fp32
  // with input fp8/bf8 (use_fp8) and bf16 output, need a fp32 D_temp, as rocblas does not support this case (fp8/bf8 input fp16/fp32 output is supported)
  // with use_fp8 true and fp8/bf8 output, need fp32 D_temp to support amax and scale operation
  void* D_temp;
  if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
      (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                   D_type == rocblas_datatype_bf8_r))) {
    if (!stream_order_alloc) {
      NVTE_CHECK_CUDA(hipMalloc(&D_temp, sizeof(float) * m * n));
    } else {
      NVTE_CHECK_CUDA(hipMallocAsync(&D_temp, sizeof(float) * m * n, stream));
    }
  } else {
    D_temp = D;
  }
  // When Ti=To=fp16 and there is no bias or gelu, D_temp points to D and we would like it to be fp16
  rocblas_datatype D_temp_type = rocblas_datatype_f32_r;
  if (!(bias || gelu) && (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
                          D_type == rocblas_datatype_f16_r)) {
    D_temp_type = rocblas_datatype_f16_r;
  }
  // When Ti=To=bf16 and there is no bias or gelu, D_temp points to D and we would like it to be bf16
  if (!(bias || gelu) && (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
                          D_type == rocblas_datatype_bf16_r)) {
    D_temp_type = rocblas_datatype_bf16_r;
  }
  // When Ti is fp8 or bf8, To=fp16, and there is no bias or gelu, D_temp points to D and we would like it to be fp16, as rocblas supports this case.
  if ((!(bias || gelu)) && (use_fp8 && D_type == rocblas_datatype_f16_r)) {
    D_temp_type = rocblas_datatype_f16_r;
  }
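  // Summary of the cases above (buffer / element type used for the GEMM result):
  //   separate fp32 buffer: (bias or gelu) with fp16/bf16 output, or
  //                         fp8/bf8 inputs with bf16 or fp8/bf8 output
  //   D itself, fp32 type:  fp32 output (with or without bias/gelu)
  //   D itself, own type:   fp16->fp16, bf16->bf16, fp8/bf8->fp16 (no bias/gelu)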
  if (accumulate && (D_temp != D || D_temp_type != D_type)) {
    DType output_dtype = get_transformer_engine_dtype(D_type);
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
        output_dtype, OType,
        //D_temp allocated only with fp32
        detail::identity_kernelLauncher<OType, float>(
            reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(D_temp), m * n, stream););
  }

  // D = alpha * (A * B) + beta * C
  if (use_fp8) {
    rocblas_computetype computeType = rocblas_compute_type_f32;
    NVTE_CHECK_ROCBLAS(rocblas_gemm_ex3(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
                                        B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
                                        D_temp_type, ldd, computeType,
                                        rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, 0));
  } else {
    rocblas_datatype computeType = rocblas_datatype_f32_r;
    uint32_t flags = rocblas_gemm_flags_none;
    if ((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r) && grad) {
      flags = rocblas_gemm_flags_fp16_alt_impl;
    }
    NVTE_CHECK_ROCBLAS(rocblas_gemm_ex(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
                                       B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
                                       D_temp_type, ldd, computeType,
                                       rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, flags));
  }
  int batch_size, input_dim, output_dim;
  if (bias && gelu) {
    if (grad) {
      // epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
      // Apply the GELU gradient to D_temp and store the result in D.
      // Apply the bias gradient to D (D is already the result of the GELU gradient) and store it in bias_ptr.
      // This case is NN.
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major.
      // The bias vector length is m, so it is reduced along axis 0 in row major.
      // (TODO): The cublasLt doc is not very clear wrt the bias gradient here.
      // It does not explicitly say that it goes through the GELU gradient first. We will need to
      // confirm in the future. As of now, this implementation of the bias gradient takes
      // the GELU gradient result in lower precision (D). It might be better to take the GELU
      // gradient result in fp32, but as that requires some kernel changes we would only do it
      // once we confirm that this is the right form of the epilogue.
      // This is for linear1 -> gelu -> linear2:
      // compute dX = dY * W for linear2 via gemm_ex(A=W, B=dY)
      batch_size = n;
      input_dim = m;  // the input dimension of the second linear layer is the output dimension of the first
      output_dim = k;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              detail::gelu_backward_kernelLauncher<OType, GType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
} else {
rocblas_datatype computeType = rocblas_datatype_f32_r; void* bias_tmp;
uint32_t flags = rocblas_gemm_flags_none; if (bias_type != rocblas_datatype_f32_r) {
if ((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r) && grad) { if (!stream_order_alloc) {
flags = rocblas_gemm_flags_fp16_alt_impl; NVTE_CHECK_CUDA(hipMalloc(
} &bias_tmp,
NVTE_CHECK_ROCBLAS(rocblas_gemm_ex(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B, sizeof(float) * input_dim)); // The bias gradient is for the first linear layer
B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp, } else {
D_temp_type, ldd, computeType, NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * input_dim, stream));
rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, flags)); }
} } else {
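
  // rocblas_gemm_ex3 is the rocBLAS entry point that accepts fp8 inputs with an fp32 compute
  // type. On the non-fp8 path, rocblas_gemm_flags_fp16_alt_impl requests the alternate fp16
  // implementation that rocBLAS provides for back-propagation workloads (see the rocBLAS
  // documentation for which architectures honor this flag).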

  int batch_size, input_dim, output_dim;
  if (bias && gelu) {
    if (grad) {
      // epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
      // Apply the GELU gradient to D_temp and store the result in D, then apply the bias
      // gradient to D (which is already the GELU-gradient result) and store it in bias_ptr.
      // This case is NN.
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major.
      // The bias vector length is m, so it is reduced along axis 0 in row major.
      // TODO: The cublasLt doc is not very clear wrt the bias gradient here.
      // It does not explicitly say that it goes through the GELU gradient first; we will need
      // to confirm in the future. As of now, this implementation of the bias gradient takes
      // the GELU gradient result in lower precision (D). It might be better to take the GELU
      // gradient result in fp32, but as that requires some kernel changes we will only do it
      // once we confirm that this is the right form of the epilogue.
      // This is for linear1 -> gelu -> linear2:
      // compute dX = dY * W for linear2
      // gemm_ex(A=W, B=dY)
      batch_size = n;
      input_dim =
          m;  // the input dimension of the second linear layer is the output dimension of the first
      output_dim = k;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              detail::gelu_backward_kernelLauncher<OType, GType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
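
      // Worked example (hypothetical sizes): linear1: 4096 -> 1024, linear2: 1024 -> 4096,
      // with 512 tokens gives m = 1024, k = 4096, n = 512. D_temp then holds the fp32 dX of
      // linear2, i.e. the dY of the GELU, viewed as (512, 1024) in row major; gelu_backward
      // scales it elementwise by gelu'(pre_gelu_out) into D, and the bias-gradient kernel
      // below reduces D over the 512-token axis into linear1's length-1024 bias gradient.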

      void* bias_tmp;
      if (bias_type != rocblas_datatype_f32_r) {
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipMalloc(
              &bias_tmp,
              sizeof(float) * input_dim));  // the bias gradient is for the first linear layer
        } else {
          NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * input_dim, stream));
        }
      } else {
        bias_tmp = bias_ptr;
      }

      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::bias_gradient_kernelLauncher<OType>(
              reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(bias_tmp), batch_size,
              input_dim, stream_order_alloc, stream););

      if (bias_type != rocblas_datatype_f32_r) {
        DType bias_dtype = get_transformer_engine_dtype(bias_type);
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            bias_dtype, BType,
            detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
                                                          reinterpret_cast<BType*>(bias_ptr),
                                                          input_dim, stream););
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipFree(bias_tmp));
        } else {
          NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
        }
      }

    } else {
      // epilogue = CUBLASLT_EPILOGUE_GELU_AUX_BIAS;
      // Add bias_ptr to D_temp and store the sum in pre_gelu_out, then apply GELU to the
      // pre-GELU output and store the result in D.
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major.
      // gemm_ex(A=W, B=X, transA=T)
      batch_size = n;
      input_dim = k;
      output_dim = m;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
                  bias_dtype, BType,
                  detail::add_bias_gelu_kernelLauncher<OType, GType, BType>(
                      reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                      reinterpret_cast<GType*>(pre_gelu_out),
                      reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
                      reinterpret_cast<const float*>(D_scale), batch_size, output_dim,
                      stream););););
    }
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      // epilogue = CUBLASLT_EPILOGUE_BGRADB;
      // Apply the bias gradient to matrix B and store it in bias_ptr, reducing along the k
      // dimension; the output bias length is n.
      // As B is transposed, it is of shape (n, k) in column major and of shape (k, n) in row
      // major, so the bias gradient vector length is n and it is reduced along axis 0 in row major.
      // The backward pass calculates the bias gradient along with dW = dY^T * X:
      // gemm_ex(A=X, B=dY, transB=T)
      batch_size = k;
      input_dim = m;
      output_dim = n;
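      // Example (hypothetical sizes): dW = dY^T * X with X of (512 tokens, 4096 features)
      // and dY of (512, 1024) gives m = 4096, n = 1024, k = 512; B = dY, seen as (512, 1024)
      // in row major, is reduced over the 512-token axis into a length-1024 bias gradient.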
      void* bias_tmp;
      if (bias_type != rocblas_datatype_f32_r) {
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipMalloc(&bias_tmp, sizeof(float) * output_dim));
        } else {
          NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * output_dim, stream));
        }
      } else {
        bias_tmp = bias_ptr;
      }

      DType input_dtype = get_transformer_engine_dtype(B_type);
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          input_dtype, IType,
          detail::bias_gradient_kernelLauncher<IType>(
              reinterpret_cast<const IType*>(B), reinterpret_cast<float*>(bias_tmp), batch_size,
              output_dim, stream_order_alloc, stream););
      if (bias_type != rocblas_datatype_f32_r) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            bias_dtype, BType,
            detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
                                                          reinterpret_cast<BType*>(bias_ptr),
                                                          output_dim, stream););
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipFree(bias_tmp));
        } else {
          NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
        }
      }
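
      // With bias present the GEMM above wrote fp32 into D_temp, so for f16/bf16 outputs
      // the result still has to be converted into D's type here.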
      if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            output_dtype, OType,
            detail::identity_kernelLauncher<float, OType>(reinterpret_cast<const float*>(D_temp),
                                                          reinterpret_cast<OType*>(D),
                                                          input_dim * output_dim, stream););
      }
    } else {
      // epilogue = CUBLASLT_EPILOGUE_BIAS;
      // Broadcast the bias, add it to D_temp, and store the result in D. The bias vector length is m.
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major.
      // gemm_ex(A=W, B=X, transA=T)
      batch_size = n;
      input_dim = k;
      output_dim = m;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              bias_dtype, BType,
              detail::add_bias_kernelLauncher<OType, BType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
                  reinterpret_cast<const float*>(D_scale), batch_size, output_dim, stream);););
    }
  } else if (gelu) {
    if (grad) {
      // epilogue = CUBLASLT_EPILOGUE_DGELU;
      // Take the input from pre_gelu_out, apply the GELU gradient to D_temp, and store the result in D.
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major.
      // gemm_ex(A=W, B=dY)
      batch_size = n;
      input_dim = m;
      output_dim = k;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              detail::gelu_backward_kernelLauncher<OType, GType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
    } else {
      // epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
      // Store (quantized) D_temp in pre_gelu_out, then apply GELU to D_temp and store the result in D.
      // D_temp is of shape (m, n) in column major and thus of shape (n, m) in row major.
      // gemm_ex(A=W, B=X, transA=T)
      batch_size = n;
      input_dim = k;
      output_dim = m;
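
      // The pre-activation copy goes to pre_gelu_out in gelu_type, while the GELU itself
      // reads the fp32 staging buffer, so the activation is not computed from an already
      // quantized value.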
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          gelu_dtype, GType,
          detail::identity_kernelLauncher<float, GType>(reinterpret_cast<const float*>(D_temp),
                                                        reinterpret_cast<GType*>(pre_gelu_out),
                                                        batch_size * output_dim, stream););
      DType output_dtype = get_transformer_engine_dtype(D_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::gelu_forward_kernelLauncher<OType>(
              reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
              reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), batch_size,
              output_dim, stream););
    }
  } else {  // No epilogue - !(bias || gelu)
    if (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                    D_type == rocblas_datatype_bf8_r)) {
      DType output_dtype = get_transformer_engine_dtype(D_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::identity_output_kernelLauncher<OType>(
              reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
              reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), m * n,
              stream););
    }
  }

  if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
      (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                   D_type == rocblas_datatype_bf8_r))) {
    if (!stream_order_alloc) {
      NVTE_CHECK_CUDA(hipFree(D_temp));
    } else {
      NVTE_CHECK_CUDA(hipFreeAsync(D_temp, stream));
    }
  }
}
#endif  // USE_ROCBLAS

void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                 const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad, void* workspace,
                 size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor* inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0,
                 bool nvte_use_rocblas = 0, int compute_stream_offset = -1) {
  /* If no backend is specified with an env variable, use hipBLASLt unless it is disabled.
     If the hipBLASLt backend is enabled and requested, use it regardless of the rocBLAS status.
     Otherwise use rocBLAS.
  */

  bool use_hipblaslt = (std::getenv("NVTE_USE_HIPBLASLT") != nullptr) || nvte_use_hipblaslt;
  bool use_rocblas = (std::getenv("NVTE_USE_ROCBLAS") != nullptr) || nvte_use_rocblas;
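  // Note: only the presence of these variables is checked, not their value. For example,
  // launching with `NVTE_USE_HIPBLASLT=1 ./app` (where app is any program built against
  // this library) requests the hipBLASLt backend; NVTE_USE_ROCBLAS works analogously.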

#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#error GEMM backend is not specified
#elif !defined(USE_HIPBLASLT)
  if (use_hipblaslt) {
    use_hipblaslt = false;
    use_rocblas = true;
    std::cout << "[NOTICE] hipBLASLt is not enabled, NVTE_USE_HIPBLASLT env is ignored\n";
  }
#elif !defined(USE_ROCBLAS)
  if (use_rocblas) {
    use_rocblas = false;
    use_hipblaslt = true;
    std::cout << "[NOTICE] rocBLAS is not enabled, NVTE_USE_ROCBLAS env is ignored\n";
  }
#else
  if (use_hipblaslt && use_rocblas) {
    use_rocblas = false;
    use_hipblaslt = true;
    // std::cout << "[NOTICE] Both GEMM backends are enabled, hipBLASLt will be used\n";
  } else if (!use_hipblaslt && !use_rocblas) {
    use_rocblas = false;
    use_hipblaslt = true;
    // std::cout << "[NOTICE] Both GEMM backends are disabled, hipBLASLt will be used\n";
  }
#endif
#ifdef USE_HIPBLASLT
  if (use_hipblaslt || !use_rocblas) {
    // Check that compute_stream_offset is valid.
    NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
    hipblasLtHandle_t handle = nullptr;
    if (compute_stream_offset != -1) {
      // Init hipblaslt handles (once, globally)
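      // std::call_once makes this one-time initialization thread-safe: concurrent callers
      // block until init_hipblaslt_handles has populated the array, after which each call
      // just indexes into the shared handle table. (When compute_stream_offset is -1 the
      // handle stays nullptr and hipblaslt_gemm is expected to fall back to its own handle.)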
      static std::once_flag init_flag;
      static hipblasLtHandle_t hipblaslt_handles[compute_num_streams];
      std::call_once(init_flag, init_hipblaslt_handles, hipblaslt_handles);
      handle = hipblaslt_handles[compute_stream_offset];
    }
    hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                   (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                   grad, workspace, workspaceSize, accumulate, use_split_accumulator, math_sm_count,
                   m_split, n_split, gemm_producer, inputCounter, stream, handle);
    return;
  }
#endif
#ifdef USE_ROCBLAS
  if (use_rocblas) {
    rocblas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? rocblas_operation_transpose : rocblas_operation_none,
                 (transb) ? rocblas_operation_transpose : rocblas_operation_none, grad, workspace,
                 workspaceSize, accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
                 gemm_producer, inputCounter, stream);
  }
#endif
}
}  // namespace transformer_engine