Commit 254fd12c authored by wenjh

Merge branch 'develop_v2.5'

parents 554296b4 c6dae0e5
@@ -3,36 +3,43 @@
 *
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/
#include <type_traits>

#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#include <unistd.h>

#include <chrono>
#include <forward_list>
#include <fstream>
#include <mutex>
#include <optional>
#include <sstream>
#include <unordered_map>
#include <vector>
#endif
#ifdef USE_ROCBLAS
#define ROCBLAS_BETA_FEATURES_API
#include <rocblas/rocblas.h>

#include <hipblaslt/hipblaslt-ext.hpp>
#include <hipcub/hipcub.hpp>
#endif
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <string>

#include "../common.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"

namespace {
@@ -54,7 +61,7 @@ static hipDataType get_hipblaslt_dtype(const transformer_engine::DType t) {
    case DType::kInt8:
      return HIP_R_8I;
    case DType::kInt32:
      return HIP_R_32I;
    default:
      NVTE_ERROR("Invalid type");
  }
@@ -81,7 +88,7 @@ rocblas_datatype get_rocblas_dtype(const transformer_engine::DType t) {
}
#endif

} //namespace

namespace transformer_engine {
@@ -91,67 +98,58 @@ namespace detail {
struct Empty {};

__device__ inline fp32 identity(fp32 value, const Empty&) { return value; }

__inline__ __device__ float gelu(float x, const Empty&) {
  float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}

__inline__ __device__ float gelu_forward(float x) {
  float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}
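
// Both GELU helpers above use the tanh approximation
//   GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
// with sqrt(2/pi) ~= 0.7978845608028654.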

template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    gelu_forward_kernel(const float* in, T* out, float* amax, const float* scale, int m, int n) {
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float x = in[id];
      float y = gelu_forward(x);
      out[id] = (T)((*scale) * y);
      thread_amax = std::fmax(std::fabs(y), thread_amax);
    }
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float x = in[id];
      float y = gelu_forward(x);
      out[id] = (T)(y);
    }
  }
}

template <typename T>
void gelu_forward_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int m,
                                 int n, hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((gelu_forward_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
                     stream, in, out, amax, scale, m, n);
}
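
// Derivative of the tanh GELU approximation, used by gelu_backward below: with
// u = kBeta * (x + kKappa * x^3),
//   d/dx GELU(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * (1 - tanh(u)^2) * kBeta * (1 + 3 * kKappa * x^2),
// and the result is scaled by the incoming gradient dy.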
__inline__ __device__ float gelu_backward(float x, float dy) {
  constexpr float kBeta = 0.7978845608028654f;
  constexpr float kKappa = 0.044715f;
  float x_sq = x * x;
  float x_cube = x_sq * x;
@@ -170,46 +168,48 @@ float gelu_backward(float x, float dy){
}

template <typename T, typename Taux>
__global__ void gelu_backward_kernel(const float* dy, T* out, const Taux* __restrict pre_gelu_out,
                                     int m, int n) {
  for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
    float x = (float)pre_gelu_out[id];
    float dx = (float)gelu_backward(x, dy[id]);
    out[id] = (T)(dx);
  }
}

template <typename T, typename Taux>
void gelu_backward_kernelLauncher(const float* in, T* out, const Taux* pre_gelu_out, int m, int n,
                                  hipStream_t stream) {
  int blocks_per_row = ceil(float(n) / 256);
  dim3 grid(min(m * blocks_per_row, 65536));
  dim3 block(min(n, 256));
  hipLaunchKernelGGL((gelu_backward_kernel<T, Taux>), dim3(grid), dim3(block), 0, stream, in, out,
                     pre_gelu_out, m, n);
}
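
// These element-wise kernels iterate with grid-stride loops, so the launchers may clamp the
// grid (e.g. min(m * blocks_per_row, 65536) above): each thread strides by
// blockDim.x * gridDim.x until all m * n elements are covered.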

template <typename T, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    add_bias_kernel(const float* in, T* out, const Tb* __restrict bias, float* amax,
                    const float* scale, int m, int n) {
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      out[id] = (T)((*scale) * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      out[id] = (T)(val);
@@ -217,43 +217,44 @@ void __launch_bounds__(THREADS_PER_BLOCK) add_bias_kernel(const float* in, T* ou
  }
}

template <typename T, typename Tb>
void add_bias_kernelLauncher(const float* in, T* out, const Tb* __restrict bias, float* amax,
                             const float* scale, int m, int n, hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((add_bias_kernel<T, Tb, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
                     stream, in, out, bias, amax, scale, m, n);
}
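
// add_bias_gelu_kernel below fuses the bias add and GELU into a single pass. For fp8 outputs
// it also applies the D scale and tracks amax via a hipcub block reduction plus
// atomicMaxFloat; pre_gelu_out always receives the unscaled bias-added value.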

template <typename T, typename Taux, typename Tb, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    add_bias_gelu_kernel(const float* in, T* out, Taux* pre_gelu_out, const Tb* __restrict bias,
                         float* amax, const float* scale, int m, int n) {
  // fp8 output flow
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    // only need to deal with amax and scale of D, no need to deal with amax and scale of pre_gelu_out
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      // pre_gelu_out guaranteed not to be fp8 type
      pre_gelu_out[id] = (Taux)(val);
      val = gelu_forward(val);
      out[id] = (T)((*scale) * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) {
      float reg_bias = (float)bias[id % n];
      float val = in[id] + reg_bias;
      pre_gelu_out[id] = (Taux)(val);
@@ -263,93 +264,91 @@ void __launch_bounds__(THREADS_PER_BLOCK) add_bias_gelu_kernel(const float* in,
}

template <typename T, typename Taux, typename Tb>
void add_bias_gelu_kernelLauncher(const float* in, T* out, Taux* pre_gelu_out,
                                  const Tb* __restrict bias, float* amax, const float* scale, int m,
                                  int n, hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * m * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((add_bias_gelu_kernel<T, Taux, Tb, THREADS_PER_BLOCK>), dim3(grid),
                     dim3(block), 0, stream, in, out, pre_gelu_out, bias, amax, scale, m, n);
}

template <typename Tin, typename T>
__global__ void identity_kernel(const Tin* in, T* out, int n) {
  for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
    Tin val = in[id];
    out[id] = (T)(val);
  }
}

template <typename Tin, typename T>
void identity_kernelLauncher(const Tin* in, T* out, int n, hipStream_t stream) {
  dim3 block, grid;
  block.x = 256;
  grid.x = ceil(n / 256.);
  hipLaunchKernelGGL((identity_kernel<Tin, T>), dim3(grid), dim3(block), 0, stream, in, out, n);
}
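
// identity_kernel above is a plain element-type conversion copy; identity_output_kernel below
// is the same copy with the fp8 scale/amax handling added on top.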

template <typename T, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    identity_output_kernel(const float* in, T* out, float* amax, const float* scale, int n) {
  if constexpr (std::is_same<T, fp8e4m3>::value || std::is_same<T, fp8e5m2>::value) {
    typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
    __shared__ typename BlockReduce::TempStorage block_temp_storage;
    float thread_amax = 0;
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
      float val = in[id];
      out[id] = (T)((*scale) * val);
      // deal with amax of D
      thread_amax = std::fmax(std::fabs(val), thread_amax);
    }
    // num_valid can be ignored since each thread amax is set to 0
    float block_amax = BlockReduce(block_temp_storage).Reduce(thread_amax, hipcub::Max());
    if (threadIdx.x == 0) {
      atomicMaxFloat(amax, block_amax);
    }
  } else {
    for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < n; id += blockDim.x * gridDim.x) {
      float val = in[id];
      out[id] = (T)(val);
    }
  }
}

template <typename T>
void identity_output_kernelLauncher(const float* in, T* out, float* amax, const float* scale, int n,
                                    hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  block.x = THREADS_PER_BLOCK;
  grid.x = ceil(1.0 * n / THREADS_PER_BLOCK);
  hipLaunchKernelGGL((identity_output_kernel<T, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0,
                     stream, in, out, amax, scale, n);
}
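
// Two column-sum (bias gradient) variants follow. bias_gradient_kernel assigns one or more
// 1024-thread blocks per column and combines block partials with atomicAdd into a pre-zeroed
// output; bias_gradient_kernel_v2 (after the next hunk) instead stages 32x32 tiles in shared
// memory, transposes them, and finishes with a warp reduction, writing each column sum once.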

template <typename Tin, int THREADS_PER_BLOCK>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    bias_gradient_kernel(const Tin* in, float* out, int m, int n) {
  typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
  __shared__ typename BlockReduce::TempStorage block_temp_storage;
  int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
  int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  int col_idx = idx / THREADS_PER_COL;
  int row_idx = idx % THREADS_PER_COL;
  float thread_data;
  if (row_idx < m) thread_data = (float)in[row_idx * n + col_idx];
  float local_sum;
  if (row_idx < (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK) {
    local_sum = BlockReduce(block_temp_storage).Sum(thread_data);
  } else {
    local_sum = BlockReduce(block_temp_storage)
                    .Sum(thread_data, m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
  }
  if (threadIdx.x == 0) atomicAdd(&out[col_idx], local_sum);
}

constexpr int kColwiseReduceTileSize = 32;
@@ -364,45 +363,47 @@ __inline__ __device__ T WarpReduceSum(T val, int max = 32) {

template <typename InputType>
__launch_bounds__(1024) __global__
void bias_gradient_kernel_v2(float* dst, const InputType* src, int M, int N) {
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float grad_sum = 0.f;
  if (j < N) {
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      grad_sum += static_cast<float>(src[i * N + j]);
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = grad_sum;
  __syncthreads();
  float sum = g_shared[threadIdx.x][threadIdx.y];
  sum = WarpReduceSum<float>(sum, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int j = blockIdx.x * blockDim.x + threadIdx.y;
    if (j < N) {
      dst[j] = static_cast<float>(sum);
    }
  }
}
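
// The launcher zeroes the output first (the commented-out v1 launch relied on atomicAdd
// accumulation) and then runs the v2 kernel over ceil(n / kColwiseReduceTileSize) blocks of
// 32x32 threads.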
template <typename Tin>
void bias_gradient_kernelLauncher(const Tin* in, float* out, int m, int n, bool stream_order_alloc,
                                  hipStream_t stream) {
  dim3 block, grid;
  constexpr int THREADS_PER_BLOCK = 1024;
  int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
  block.x = THREADS_PER_BLOCK;
  grid.x = BLOCKS_PER_COL * n;
  if (!stream_order_alloc) {
    NVTE_CHECK_CUDA(hipMemset(out, 0, n * sizeof(float)));
  } else {
    NVTE_CHECK_CUDA(hipMemsetAsync(out, 0, n * sizeof(float), stream));
  }
  // hipLaunchKernelGGL(( bias_gradient_kernel<Tin, THREADS_PER_BLOCK>), dim3(grid), dim3(block), 0, stream, in, out, m, n);
  int B = (n - 1) / kColwiseReduceTileSize + 1;
  bias_gradient_kernel_v2<Tin>
      <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(out, in, m, n);
}

} // namespace detail

transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t) {
  using namespace transformer_engine;
@@ -421,28 +422,25 @@ transformer_engine::DType get_transformer_engine_dtype(const rocblas_datatype t)
      NVTE_ERROR("Invalid type");
  }
}
#endif //USE_ROCBLAS

#ifdef USE_HIPBLASLT
namespace {

static class HandlePool {
 public:
  hipblasLtHandle_t get(int device_id) {
    std::lock_guard<std::mutex> lock(mt);
    if (pool.empty()) {
      int device_count = 0;
      NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
      pool.resize(device_count);
      return nullptr;
    }
    if (!pool[device_id].empty()) {
      hipblasLtHandle_t h = pool[device_id].front();
      pool[device_id].pop_front();
      return h;
@@ -451,27 +449,21 @@ public:
    return nullptr;
  }

  hipblasLtHandle_t obtain(int device_id) {
    hipblasLtHandle_t h = get(device_id);
    if (h == nullptr) {
      NVTE_CHECK_HIPBLASLT(hipblasLtCreate(&h));
    }
    return h;
  }

  void store(const std::vector<hipblasLtHandle_t>& handles) {
    std::lock_guard<std::mutex> lock(mt);
    if (pool.empty()) {
      std::cout << "[ERROR] Attempt to store handles to invalid pool" << std::endl;
    }
    for (unsigned int i = 0; i < pool.size(); i++) {
      if (handles[i] != nullptr) {
        pool[i].push_front(handles[i]);
      }
    }
@@ -480,10 +472,8 @@ public:
  ~HandlePool() {
#if DESTROY_HIPBLASLT_HANDLES_POOL
    std::lock_guard<std::mutex> lock(mt);
    for (auto& hlist : pool) {
      for (auto& h : hlist) {
        hipblasLtDestroy(h);
      }
    }
@@ -491,12 +481,9 @@ public:
#endif
  }

  inline size_t get_size() const { return pool.size(); }

 private:
  std::mutex mt;
  using Pool = std::vector<std::forward_list<hipblasLtHandle_t>>;
  // Order of destructors between thread_local and global is not actually guaranteed
@@ -506,23 +493,17 @@ private:
#if DESTROY_HIPBLASLT_HANDLES_POOL
  Pool pool;
#else
  Pool& pool = *new Pool();
#endif
} handle_pool;
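
// Handle management: handle_pool is a process-wide, mutex-protected pool of hipblasLt handles
// per device. HandleCache below is a thread_local front end that borrows one handle per
// device on first use and hands its handles back to the pool from its destructor.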
thread_local static class HandleCache {
 public:
  hipblasLtHandle_t get(int device_id) const { return d.empty() ? nullptr : d[device_id]; }

  hipblasLtHandle_t obtain(int device_id) {
    hipblasLtHandle_t h = get(device_id);
    if (h) {
      return h;
    }
    h = handle_pool.obtain(device_id);
@@ -530,126 +511,106 @@ public:
    return h;
  }

  void set(int device_id, hipblasLtHandle_t h) {
    if (d.empty()) {
      d.resize(handle_pool.get_size());
    }
    d[device_id] = h;
  }

  ~HandleCache() {
    if (!d.empty()) {
      handle_pool.store(d);
    }
  }

 private:
  std::vector<hipblasLtHandle_t> d;
} cached_handles;
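
// csv_helper prefixes every field except the first with the separator; streaming end{} stops
// separation so a row terminator can follow. A minimal usage sketch:
//   csv_helper csv(std::cout, ',');
//   csv << "m" << "n" << "k" << csv_helper::end() << "\n";  // prints "m,n,k" and a newline
// Streaming start{} re-arms the separator logic for the next row on the same helper.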
class csv_helper {
 public:
  struct start {};
  struct end {};
  csv_helper(std::ostream& os, char sep_val)
      : m_os{os}, m_sep_val(sep_val), m_start(true), m_sep("") {}

  csv_helper& operator<<(const start&) {
    m_start = true;
    return *this;
  }

  csv_helper& operator<<(const end&) {
    m_sep = "";
    m_start = false;
    return *this;
  }

  template <typename T>
  csv_helper& operator<<(const T& v) {
    m_os << m_sep << v;
    if (m_start) {
      m_start = false;
      m_sep = m_sep_val;
    }
    return *this;
  }

 private:
  std::ostream& m_os;
  char m_sep_val;
  bool m_start;
  std::string m_sep;
};

template <typename T>
class NameMapper {
 public:
  NameMapper(const std::unordered_map<T, std::string_view>& name_map) : map(name_map) {}
  const std::string_view& getName(const T& val) { return map.at(val); }
  T getValue(const std::string& name, const char* label = "",
             std::function<bool(const T&)> filter = nullptr) {
    for (auto iter = map.begin(); iter != map.end(); ++iter) {
      if ((name == iter->second) && (!filter || filter(iter->first))) return iter->first;
    }
    NVTE_ERROR("Invalid ", label, " name: ", name);
  }

 protected:
  const std::unordered_map<T, std::string_view>& map;
};

static std::unordered_map<hipDataType, std::string_view> type_name_map = {
    {HIP_R_32F, "float32"},
    {HIP_R_16F, "float16"},
    {HIP_R_16BF, "bfloat16"},
    {HIP_R_8F_E4M3_FNUZ, "float8e4m3"},
    {HIP_R_8F_E5M2_FNUZ, "float8e5m2"},
#if HIP_VERSION >= 60300000
    {HIP_R_8F_E4M3, "float8e4m3"},
    {HIP_R_8F_E5M2, "float8e5m2"},
#endif
};
static NameMapper<hipDataType> typeNameMapper(type_name_map);

static std::unordered_map<hipblasOperation_t, std::string_view> trans_name_map = {
    {HIPBLAS_OP_N, "N"}, {HIPBLAS_OP_T, "T"}};
static NameMapper<hipblasOperation_t> transposeNameMapper(trans_name_map);

static std::unordered_map<hipblasLtEpilogue_t, std::string_view> epi_name_map = {
    {HIPBLASLT_EPILOGUE_DEFAULT, "-"},        {HIPBLASLT_EPILOGUE_BIAS, "bias"},
    {HIPBLASLT_EPILOGUE_GELU_AUX, "geluaux"}, {HIPBLASLT_EPILOGUE_GELU_AUX_BIAS, "geluauxbias"},
    {HIPBLASLT_EPILOGUE_DGELU, "dgelu"},      {HIPBLASLT_EPILOGUE_DGELU_BGRAD, "dgelubgrad"},
    {HIPBLASLT_EPILOGUE_BGRADB, "bgradb"}};
static NameMapper<hipblasLtEpilogue_t> epilogueNameMapper(epi_name_map);

static std::unordered_map<hipblasComputeType_t, std::string_view> comp_name_map = {
    {HIPBLAS_COMPUTE_32F, "f32"}};
static NameMapper<hipblasComputeType_t> computeNameMapper(comp_name_map);
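
// GemmAlgoCache below memoizes hipblasLt algorithm choices keyed by device capability plus the
// full GEMM configuration (types, shapes, leading dimensions, transposes, epilogue). Entries
// can be reloaded from and persisted to CSV via the TE_HIPBLASLT_ALGO_LOAD and
// TE_HIPBLASLT_ALGO_SAVE environment variables; the name mappers above supply the
// human-readable column values.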
static class GemmAlgoCache {
 public:
  struct Key {
    int deviceCap;
    hipDataType a_type, b_type, d_type, bias_type;
@@ -658,61 +619,58 @@ public:
    hipblasOperation_t transa, transb;
    hipblasLtEpilogue_t epilogue;

    Key(int deviceCap_, hipDataType a_type_, hipDataType b_type_, hipDataType d_type_,
        hipDataType bias_type_, int m_, int n_, int k_, int lda_, int ldb_, int ldd_,
        hipblasOperation_t transa_, hipblasOperation_t transb_, hipblasLtEpilogue_t epilogue_)
        : deviceCap(deviceCap_),
          a_type(a_type_),
          b_type(b_type_),
          d_type(d_type_),
          bias_type(bias_type_),
          m(m_),
          n(n_),
          k(k_),
          lda(lda_),
          ldb(ldb_),
          ldd(ldd_),
          transa(transa_),
          transb(transb_),
          epilogue(epilogue_) {}
    Key() {}

    bool operator==(const Key& val) const {
      return ((deviceCap == val.deviceCap) && (a_type == val.a_type) && (b_type == val.b_type) &&
              (d_type == val.d_type) && (bias_type == val.bias_type) && (m == val.m) &&
              (n == val.n) && (k == val.k) && (lda == val.lda) && (ldb == val.ldb) &&
              (ldd == val.ldd) && (transa == val.transa) && (transb == val.transb) &&
              (epilogue == val.epilogue));
    }

    struct Comp {
      bool operator()(const Key& lhs, const Key& rhs) const {
        return ::std::string_view((const char*)&lhs, sizeof(lhs)) <
               ::std::string_view((const char*)&rhs, sizeof(rhs));
      }
    };
  };
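
  // Note: Key::Comp orders keys by their raw bytes via std::string_view, which is enough for
  // multimap lookup as long as equal keys are bitwise identical; it stands in for a
  // field-by-field less-than comparator.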

  void init() {
    std::lock_guard<std::mutex> lock(mt);
    int device_count = 0;
    NVTE_CHECK_CUDA(hipGetDeviceCount(&device_count));
    dev_cap.resize(device_count);
    for (int i = 0; i < device_count; i++) {
      hipDeviceProp_t prop;
      NVTE_CHECK_CUDA(hipGetDeviceProperties(&prop, i));
      dev_cap[i] = prop.major * 100 + prop.minor;
    }
    load_();
    save_();
  }

  inline int device_cap(int device_id) {
    if (dev_cap.empty()) init();
    return dev_cap[device_id];
  }
@@ -722,28 +680,25 @@ public:
    int index;
    size_t ws_size_min;
    size_t ws_size_max;
    Algo() : algo(), index(-1), algoId(), ws_size_min(0), ws_size_max(0) {}
    Algo(int idx, int64_t id, size_t ws_min, size_t ws_max)
        : algo(), index(idx), algoId(id), ws_size_min(ws_min), ws_size_max(ws_max) {}
    inline bool hasId() { return index >= 0; }
    const static inline int64_t getAlgoId(const hipblasLtMatmulAlgo_t& algo) {
      return *(const int64_t*)&algo;
    }
  };

  bool find(const Key& cfg, size_t ws_size, Algo& algo) {
    std::lock_guard<std::mutex> lock(mt);
    if (auto* pentry = find_(cfg, ws_size, ws_size); pentry != nullptr) {
      algo = *pentry;
      return true;
    }
    return false;
  }
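
  // store() records an algorithm for the workspace-size range [ws_size_min, ws_size_max];
  // overlapping entries are replaced, trimmed, or merged so that find(cfg, ws_size) matches at
  // most one entry for any given size.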
  void store(const Key& cfg, const Algo& algo) {
    size_t ws_size_min = algo.ws_size_min;
    size_t ws_size_max = algo.ws_size_max;
    NVTE_CHECK(ws_size_max >= ws_size_min, "Invalid WS size");
@@ -751,23 +706,17 @@ public:

    //Remove overlapping with existing entries;
    while (auto* pentry = find_(cfg, ws_size_min, ws_size_max)) {
      if (pentry->ws_size_min <= ws_size_min && pentry->ws_size_max >= ws_size_max) {
        *pentry = algo;
        save_();
        return;
      }
      if (ws_size_max > pentry->ws_size_max) {
        ws_size_min = pentry->ws_size_max + 1;
      } else if (ws_size_min < pentry->ws_size_min) {
        ws_size_max = pentry->ws_size_min - 1;
      } else {
        //Should never be here
        NVTE_ERROR("Cannot merge WS size range");
      }
@@ -775,14 +724,11 @@ public:

    //Merge to adjusted entry if possible
    auto* pentry = find_(cfg, ws_size_min - 1, ws_size_min);
    if (pentry && pentry->algoId == algo.algoId) {
      pentry->algo = algo.algo;
      pentry->ws_size_max = ws_size_max;
      save_();
    } else {
      auto it = d.emplace(cfg, algo);
      it->second.ws_size_min = ws_size_min;
      it->second.ws_size_max = ws_size_max;
@@ -790,40 +736,32 @@ public:
    }
  }

 protected:
  Algo* find_(const Key& cfg, size_t ws_min, size_t ws_max) {
    const auto key_range = d.equal_range(cfg);
    for (auto i = key_range.first; i != key_range.second; i++) {
      if (ws_min <= i->second.ws_size_max && ws_max >= i->second.ws_size_min) {
        return &i->second;
      }
    }
    return nullptr;
  }
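
  // The cache is persisted as CSV: header_ emits the column legend that load_ checks the
  // first file line against and that save_ rewrites, with one cached algorithm per data row.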
  void header_(std::ostream& ofs) {
    csv_helper fs(ofs, csv_sep);
    fs << "dev_cap" << "m" << "n" << "k" << "trans_a" << "trans_b"
       << "type_a" << "type_b" << "type_d" << "bias_type"
       << "lda" << "ldb" << "ldd" << "epi" << "comp" << "scale"
       << "ws_min" << "ws_max" << "algo_id" << "aidx";
  }

  void load_() {
    const char* env = std::getenv("TE_HIPBLASLT_ALGO_LOAD");
    if (env == nullptr || env[0] == '\0') {
      return;
    }
    std::ifstream ifs{env};
    if (!ifs.is_open()) {
      std::cerr << "Could not load autotune results storage " << env << "\n";
      return;
    }
@@ -831,7 +769,7 @@ protected:
    Key cfg;
    std::string line;
    std::getline(ifs, line);  // the first line with legend
    {
      std::ostringstream hline;
      header_(hline);
@@ -841,12 +779,10 @@
      }
    }

    while (std::getline(ifs, line)) {
      line.erase(0, line.find_first_not_of(" \t\n\r\f\v"));
      if (auto pos = line.find_last_not_of(" \t\n\r\f\v"); pos != std::string::npos) {
        line.resize(pos + 1);
      }
      if (line.empty() || line[0] == '#') continue;
      std::istringstream is(line);
@@ -861,10 +797,8 @@ protected:
      //Filter out entries for devices not present on the current system
      bool b_found = false;
      for (int i = 0; i < dev_cap.size(); i++) {
        if (dev_cap[i] == cfg.deviceCap) {
          b_found = true;
          break;
        }
@@ -882,23 +816,21 @@ protected:
      std::getline(is, comp, csv_sep);
      std::getline(is, scale, csv_sep);
      is >> ws_min >> c >> ws_max >> c >> algo_id >> c >> algo_idx;
      if (is.bad()) {
        std::cerr << "Parsing CSV line failed: " << line << "\n";
        return;
      }
      if (ws_min > ws_max) {
        std::cout << "[WARNING] Invalid WS size at " << line << "\n";
        continue;
      }

#if HIP_VERSION >= 60300000
      auto fp8_filter = [](const hipDataType& val) {
        return (val != HIP_R_8F_E4M3_FNUZ && val != HIP_R_8F_E5M2_FNUZ);
      };
#else
      auto fp8_filter = nullptr;
#endif
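
      // On ROCm >= 6.3 the FNUZ and OCP fp8 types share the names "float8e4m3"/"float8e5m2"
      // in type_name_map, so fp8_filter rejects the FNUZ variants to keep name lookups
      // unambiguous; its call sites fall outside this hunk.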
@@ -916,28 +848,23 @@ protected:
      cfg.epilogue = epilogueNameMapper.getValue(epi, "epi");

      //Check and filter out compute and scale types
      if (computeNameMapper.getValue(comp, "comp") != HIPBLAS_COMPUTE_32F ||
          typeNameMapper.getValue(scale, "scale") != HIP_R_32F) {
        continue;
      }

      if (find_(cfg, ws_min, ws_max)) {
        std::cout << "[WARNING] Duplicated/overlapped entry in algo cache\n";
        continue;
      }
      d.emplace(cfg, Algo(algo_idx, algo_id, ws_min, ws_max));
    }
  }

  bool can_save_(bool reopen = false) {
    if (!save_fs) {
      const char* temp = std::getenv("TE_HIPBLASLT_ALGO_SAVE");
      if (temp == nullptr || temp[0] == '\0') {
        return false;
      }
@@ -954,60 +881,51 @@ protected:
      std::cout << "Saving autotune results to " << save_fs_name << "\n";
    }

    if (reopen) {
      if (save_fs->is_open()) {
        save_fs->close();
      }
      save_fs->open(save_fs_name, std::ios_base::trunc);
    }
    if (save_fs->is_open() && !save_fs->bad()) {
      return true;
    } else {
      if (reopen) std::cerr << "Could not open autotune results storage " << save_fs_name << "\n";
      return false;
    }
  }
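
  // save_() rewrites the whole store: can_save_(true) reopens TE_HIPBLASLT_ALGO_SAVE with
  // truncation, then the legend and all cached entries are written. The save_(cfg, algo)
  // overload appends a single row to the already-open stream.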
  void save_() {
    if (!can_save_(true)) {
      return;
    }
    header_(*save_fs);
    *save_fs << "\n";
    for (const auto& elem : d) {
      save_(elem.first, elem.second);
    }
  }

  void save_(const Key& cfg, const Algo& algo) {
    if (!can_save_()) {
      return;
    }
    csv_helper csv(*save_fs, csv_sep);
    csv << cfg.deviceCap << cfg.m << cfg.n << cfg.k << transposeNameMapper.getName(cfg.transa)
        << transposeNameMapper.getName(cfg.transb) << typeNameMapper.getName(cfg.a_type)
        << typeNameMapper.getName(cfg.b_type) << typeNameMapper.getName(cfg.d_type)
        << ((cfg.bias_type == (hipDataType)-1) ? "-" : typeNameMapper.getName(cfg.bias_type))
        << cfg.lda << cfg.ldb << cfg.ldd << epilogueNameMapper.getName(cfg.epilogue)
        << computeNameMapper.getName(HIPBLAS_COMPUTE_32F) << typeNameMapper.getName(HIP_R_32F)
        << algo.ws_size_min << algo.ws_size_max << algo.algoId << algo.index << csv_helper::end()
        << "\n";
  }

 private:
  std::vector<int> dev_cap;
  constexpr static char csv_sep = ',';
  std::unique_ptr<std::ofstream> save_fs;
  std::string save_fs_name;
  std::mutex mt;
@@ -1018,23 +936,19 @@ private:
  std::multimap<Key, Algo, Key::Comp> d;
} algoCache;

static inline int getIntEnv(const char* name, int defval, int minval) {
  int val = defval;
  const char* env = std::getenv(name);
  if (env != nullptr && env[0] != '\0') {
    val = atoi(env);
    if (val < minval) {
      val = minval;
    }
  }
  return val;
}
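
// Example: getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0) returns the parsed value of that
// environment variable clamped below at the given minimum, or the default when the variable
// is unset or empty.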

} //namespace

/* Warning: only call once per device!
 * When calling nvte_multi_stream_cublas_gemm with hipblaslt backend
@@ -1048,39 +962,23 @@ static void init_hipblaslt_handles(hipblasLtHandle_t* hipblaslt_handles) {
  }
}

void hipblaslt_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                    const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                    int ldb, int ldd, hipblasOperation_t transa, hipblasOperation_t transb,
                    bool grad, void* workspace, size_t workspaceSize, bool accumulate,
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, const Tensor* inputCounter, hipStream_t stream,
                    hipblasLtHandle_t handle) {
  void* A = inputA->data.dptr;
  void* A_scale_inverse = inputA->scale_inv.dptr;
  void* B = inputB->data.dptr;
  void* B_scale_inverse = inputB->scale_inv.dptr;
  void* D = outputD->data.dptr;
  void* bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void* pre_gelu_out = outputPreGelu->data.dptr;
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const hipDataType A_type = get_hipblaslt_dtype(inputA->data.dtype);
  const hipDataType B_type = get_hipblaslt_dtype(inputB->data.dtype);
  const hipDataType D_type = get_hipblaslt_dtype(outputD->data.dtype);
@@ -1106,38 +1004,33 @@ void hipblaslt_gemm(const Tensor *inputA,
  if (handle == nullptr) {
    handle = cached_handles.get(device_id);
    if (handle == nullptr) {
      handle = cached_handles.obtain(device_id);
    }
  }

  hipblasLtMatmulDesc_t operationDesc = nullptr;
  hipblasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, Ddesc = nullptr;
  hipblasLtMatmulPreference_t preference = nullptr;
  hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT;
  int64_t ld_gelumat = (int64_t)ldd;

  // default to tf32 except for e5m2 inputs where the config is not supported
  hipblasComputeType_t gemm_compute_type = HIPBLAS_COMPUTE_32F;

  // Create matrix descriptors. Not setting any extra attributes.
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Adesc, A_type, transa == HIPBLAS_OP_N ? m : k,
                                                   transa == HIPBLAS_OP_N ? k : m, lda));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Bdesc, B_type, transb == HIPBLAS_OP_N ? k : n,
                                                   transb == HIPBLAS_OP_N ? n : k, ldb));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutCreate(&Ddesc, D_type, m, n, ldd));

  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescCreate(&operationDesc, gemm_compute_type, HIP_R_32F));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSA,
                                                       &transa, sizeof(transa)));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_TRANSB,
                                                       &transb, sizeof(transb)));

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
@@ -1151,18 +1044,15 @@ void hipblaslt_gemm(const Tensor *inputA,
                                                         &fastAccuMode,
                                                         sizeof(fastAccuMode)));
    */
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                        &A_scale_inverse, sizeof(A_scale_inverse)));
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulDescSetAttribute(operationDesc, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                        &B_scale_inverse, sizeof(B_scale_inverse)));
    if (bias) {
      NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
          operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type)));
    }
  }
@@ -1172,15 +1062,13 @@ void hipblaslt_gemm(const Tensor *inputA,
    } else {
      epilogue = HIPBLASLT_EPILOGUE_GELU_AUX_BIAS;
    }
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
        operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
                                                         HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
                                                         &pre_gelu_out, sizeof(pre_gelu_out)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
        operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
  } else if (bias) {
    if (grad) {
      // grad output is always input B
...@@ -1188,41 +1076,36 @@ void hipblaslt_gemm(const Tensor *inputA, ...@@ -1188,41 +1076,36 @@ void hipblaslt_gemm(const Tensor *inputA,
} else { } else {
epilogue = HIPBLASLT_EPILOGUE_BIAS; epilogue = HIPBLASLT_EPILOGUE_BIAS;
} }
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
HIPBLASLT_MATMUL_DESC_BIAS_POINTER, operationDesc, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_ptr, sizeof(bias_ptr)));
&bias_ptr, sizeof(bias_ptr)));
} else if (gelu) { } else if (gelu) {
if (grad) { if (grad) {
epilogue = HIPBLASLT_EPILOGUE_DGELU; epilogue = HIPBLASLT_EPILOGUE_DGELU;
} else { } else {
epilogue = HIPBLASLT_EPILOGUE_GELU_AUX; epilogue = HIPBLASLT_EPILOGUE_GELU_AUX;
} }
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc, NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(operationDesc,
HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&ld_gelumat, sizeof(ld_gelumat))); &pre_gelu_out, sizeof(pre_gelu_out)));
NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ld_gelumat, sizeof(ld_gelumat)));
} }
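  // Visible mapping above: bias+gelu (fwd) -> GELU_AUX_BIAS; bias only (fwd) -> BIAS;
  // gelu only -> DGELU (grad) / GELU_AUX (fwd). The bias-pointer and aux-pointer attributes
  // tell hipBLASLt where to read the bias and where to stash the pre-GELU matrix.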
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmulDescSetAttribute(
      operationDesc, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue)));
  GemmAlgoCache::Key gemm_cfg(algoCache.device_cap(device_id), A_type, B_type, D_type,
                              use_fp8 ? bias_type : (hipDataType)-1, m, n, k, lda, ldb, ldd, transa,
                              transb, epilogue);
  GemmAlgoCache::Algo cached_algo;
  if (algoCache.find(gemm_cfg, workspaceSize, cached_algo) == 0 || !cached_algo.algo.has_value()) {
    int firstAlgo = getIntEnv("TE_HIPBLASLT_ALGO_SELECTION", 0, 0);
    int tuneLoopCount = getIntEnv("TE_HIPBLASLT_TUNING_RUN_COUNT", 0, 0);
    int algoTuneCount = 1;
    std::vector<hipblasLtMatmulHeuristicResult_t> algoArr;
    bool logTuning = getIntEnv("TE_HIPBLASLT_LOG_TUNING", 0, 0) != 0;
    if (tuneLoopCount) {
      /* HIPBLASLT may return hundreds of algos for some configs
       * Limit amount by default. User may override with env
       */
@@ -1230,35 +1113,31 @@ void hipblaslt_gemm(const Tensor *inputA,
      algoTuneCount = getIntEnv("TE_HIPBLASLT_TUNING_ALGO_COUNT", defaultAlgoCount, 1);
    }
    algoTuneCount += firstAlgo;
    int algoTotalCount =
        cached_algo.hasId() ? std::max(algoTuneCount, (cached_algo.index + 1)) : algoTuneCount;
    algoArr.resize(algoTotalCount);
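    // Illustrative tuning invocation (env names from the getIntEnv calls above; the launch
    // command itself is hypothetical):
    //   TE_HIPBLASLT_TUNING_RUN_COUNT=10 TE_HIPBLASLT_TUNING_ALGO_COUNT=64 TE_HIPBLASLT_LOG_TUNING=1 python train.py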
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceCreate(&preference));
    NVTE_CHECK_HIPBLASLT(
        hipblasLtMatmulPreferenceSetAttribute(preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                              &workspaceSize, sizeof(workspaceSize)));
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Ddesc,
                                                         Ddesc, preference, algoTotalCount,
                                                         algoArr.data(), &algoTotalCount));
    algoArr.resize(algoTotalCount);
    NVTE_CHECK_HIPBLASLT(hipblasLtMatmulPreferenceDestroy(preference));
    //If cached algo exists in persistent storage we just need to find matching hipblasLtMatmulAlgo_t
    if (cached_algo.hasId()) {
      int idx = (cached_algo.index < algoTotalCount) ? cached_algo.index : 0;
      for (int i = 0; i < algoTotalCount; i++) {
        const auto& algo = algoArr[idx];
        if (algo.state == HIPBLAS_STATUS_SUCCESS) {
          if (cached_algo.algoId == cached_algo.getAlgoId(algo.algo)) {
            cached_algo.algo = algo.algo;
            if (algo.workspaceSize != cached_algo.ws_size_min || idx != cached_algo.index) {
              cached_algo.ws_size_min = algo.workspaceSize;
              cached_algo.index = idx;
              algoCache.store(gemm_cfg, cached_algo);
@@ -1268,20 +1147,17 @@ void hipblaslt_gemm(const Tensor *inputA,
        }
        idx = (idx + 1) % algoTotalCount;
      }
      if (logTuning && !cached_algo.algo.has_value()) {
        std::cout << "[WARNING] Cannot find cached algoId " << cached_algo.algoId
                  << " in hipBLASLt results" << std::endl;
      }
    }
    //No suitable entry in autotune cache or could not find matched algo in hipBLASLt results
    if (!cached_algo.algo.has_value()) {
      int bestAlgo = -1;
      algoTuneCount = std::min(algoTuneCount, algoTotalCount);
      if (tuneLoopCount > 0) {
        if (logTuning)
          std::cout << "[INFO] Perform hipBLASLt algo selection on GPU" << device_id
                    << " in range [" << firstAlgo << "-" << (algoTuneCount - 1) << "] with "
@@ -1291,75 +1167,57 @@ void hipblaslt_gemm(const Tensor *inputA,
        hipStream_t profilingStream;
        NVTE_CHECK_CUDA(hipStreamCreateWithFlags(&profilingStream, hipStreamNonBlocking));
        using tuning_clock = std::chrono::steady_clock;
        tuning_clock::now();  //the first call takes a little longer, so do it outside the loop
        tuning_clock::duration bestTime = tuning_clock::duration::max();
        for (int algo = firstAlgo; algo < algoTuneCount; algo++) {
          if (algoArr[algo].state != HIPBLAS_STATUS_SUCCESS) {
            continue;
          }
          // Warm-up call
          NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                               static_cast<const void*>(&one),         /* alpha */
                                               A,                                      /* A */
                                               Adesc, B,                               /* B */
                                               Bdesc, static_cast<const void*>(&beta), /* beta */
                                               D,                                      /* C */
                                               Ddesc, D,                               /* D */
                                               Ddesc, &algoArr[algo].algo,             /* algo */
                                               workspace,                          /* workspace */
                                               workspaceSize, profilingStream));   /* stream */
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
          //Profiling loop
          tuning_clock::time_point startTime = tuning_clock::now();
          for (int loop = 0; loop < tuneLoopCount; loop++) {
            NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                                 static_cast<const void*>(&one),       /* alpha */
                                                 A,                                    /* A */
                                                 Adesc, B,                             /* B */
                                                 Bdesc, static_cast<const void*>(&beta), /* beta */
                                                 D,                                    /* C */
                                                 Ddesc, D,                             /* D */
                                                 Ddesc, &algoArr[algo].algo,           /* algo */
                                                 workspace,                        /* workspace */
                                                 workspaceSize, profilingStream)); /* stream */
          }
          NVTE_CHECK_CUDA(hipStreamSynchronize(profilingStream));
          tuning_clock::duration algoTime = tuning_clock::now() - startTime;
          if (algoTime < bestTime) {
            bestAlgo = algo;
            bestTime = algoTime;
          }
        }
        NVTE_CHECK_CUDA(hipStreamDestroy(profilingStream));
        if (bestAlgo >= 0) {
          if (logTuning)
            std::cout << "[INFO] Select hipBLASLt algo " << bestAlgo << " with time "
                      << std::chrono::duration_cast<std::chrono::nanoseconds>(bestTime).count() /
                             tuneLoopCount
                      << " ns" << std::endl;
        }
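        // Each candidate's score is the sync-to-sync wall time of tuneLoopCount back-to-back
        // launches on the dedicated non-blocking stream; the warm-up call above keeps one-time
        // kernel-load cost out of the measurement, and the figure printed here is the per-run
        // average in nanoseconds.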
      } else if (firstAlgo < algoTuneCount) {
        bestAlgo = firstAlgo;
      }
@@ -1377,30 +1235,24 @@ void hipblaslt_gemm(const Tensor *inputA,
      cached_algo.ws_size_max = workspaceSize;
      if (logTuning)
        std::cout << "[INFO] Use hipBLASLt algo [" << bestAlgo << "] " << cached_algo.algoId
                  << std::endl;
      algoCache.store(gemm_cfg, cached_algo);
    }
  }
  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc,
                                       static_cast<const void*>(&one),         /* alpha */
                                       A,                                      /* A */
                                       Adesc, B,                               /* B */
                                       Bdesc, static_cast<const void*>(&beta), /* beta */
                                       D,                                      /* C */
                                       Ddesc, D,                               /* D */
                                       Ddesc, &cached_algo.algo.value(),       /* algo */
                                       workspace,                              /* workspace */
                                       workspaceSize, stream));                /* stream */
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Ddesc));
  NVTE_CHECK_HIPBLASLT(hipblasLtMatrixLayoutDestroy(Bdesc));
@@ -1409,118 +1261,86 @@
}
class userArgsManager {
 public:
  userArgsManager() {}
  ~userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : userArgs_map_) {
      // Pinned host memory from hipHostMalloc must be released with hipHostFree
      // (not hipFree); only one userArgs buffer exists per device.
      hipHostFree(device_pair.second);
    }
  }
  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the userArgs for this device exists
    auto device_it = userArgs_map_.find(device_id);
    if (device_it != userArgs_map_.end()) {
      return device_it->second;
    }
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* userArgs;
    NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
    // Store the userArgs in the map for this device
    userArgs_map_[device_id] = userArgs;
    return userArgs;
  }

 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*>
      userArgs_map_;  // Map from device_id to the device's pinned UserArguments buffer
  std::mutex mutex_;
};
class d_userArgsManager {
 public:
  d_userArgsManager() {}
  ~d_userArgsManager() {
    // Release all userArgs when the manager is destroyed
    for (auto& device_pair : d_userArgs_map_) {
      hipFree(device_pair.second);  // Only one userArgs buffer per device
    }
  }
  // Get a userArgs for the given device (creates if necessary)
  hipblaslt_ext::UserArguments* get(int device_id, size_t size) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Check if the userArgs for this device exists
    auto device_it = d_userArgs_map_.find(device_id);
    if (device_it != d_userArgs_map_.end()) {
      return device_it->second;
    }
    // Create a new userArgs for this device if it doesn't exist
    hipblaslt_ext::UserArguments* d_userArgs;
    NVTE_CHECK_CUDA(hipMalloc(&d_userArgs, size * sizeof(hipblaslt_ext::UserArguments)));
    // Store the userArgs in the map for this device
    d_userArgs_map_[device_id] = d_userArgs;
    return d_userArgs;
  }

 private:
  std::unordered_map<int, hipblaslt_ext::UserArguments*>
      d_userArgs_map_;  // Map from device_id to the device's UserArguments buffer
  std::mutex mutex_;
};
// Define a static userArgs manager
static userArgsManager UAManager;
static d_userArgsManager d_UAManager;
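// Usage note: hipblaslt_goupedgemm() below fetches one host (pinned) and one device buffer
// per GPU from these managers and reuses them across calls; the buffers are released only
// when the static managers are destroyed at process exit.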
void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const Tensor*>& inputB,
                          std::vector<Tensor*>& outputD, std::vector<int64_t>& m,
                          std::vector<int64_t>& n, std::vector<int64_t>& k, std::vector<int64_t>& b,
                          hipblasOperation_t transa, hipblasOperation_t transb, void* workspace,
                          size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                          int math_sm_count, hipStream_t stream, int compute_stream_offset = 0) {
  // Check that compute_stream_offset is valid.
  NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
@@ -1529,7 +1349,6 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
  hipGetDevice(&device_id);
  hipblaslt_ext::UserArguments* userArgs = UAManager.get(device_id, m.size());
  hipblaslt_ext::UserArguments* d_userArgs = d_UAManager.get(device_id, m.size());
  // hipblaslt_ext::UserArguments* userArgs;
  // NVTE_CHECK_CUDA(hipHostMalloc(&userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments)));
@@ -1557,63 +1376,53 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
  int int_zero = 0;
  int int_beta = int_zero;
  bool use_int8 = false;
  if ((A_type == HIP_R_8I) && (B_type == HIP_R_8I) && (D_type == HIP_R_32I)) {
    NVTE_CHECK(!accumulate, "Int8 gemm does not support accumulate.");
    use_int8 = true;
    computeType = HIPBLAS_COMPUTE_32I;
  }
  hipblaslt_ext::GemmPreference gemmPref;
  gemmPref.setMaxWorkspaceBytes(workspaceSize);
  hipblaslt_ext::GroupedGemm groupedgemm(handle, transa, transb, A_type, B_type, D_type, D_type,
                                         computeType);
  std::vector<hipblaslt_ext::GemmEpilogue> epilogue{
      hipblaslt_ext::
          GemmEpilogue()};  // No action needed, default is HIPBLASLT_EPILOGUE_DEFAULT. (Gemm only)
  std::vector<hipblaslt_ext::GemmInputs> inputs(m.size());
  for (int i = 0; i < m.size(); i++) {
    inputs[i].a = inputA[i]->data.dptr;
    inputs[i].b = inputB[i]->data.dptr;
    inputs[i].c = outputD[i]->data.dptr;
    inputs[i].d = outputD[i]->data.dptr;
    inputs[i].alpha = use_int8 ? static_cast<void*>(&int_one) : static_cast<void*>(&one);
    inputs[i].beta = use_int8 ? static_cast<void*>(&int_beta) : static_cast<void*>(&beta);
  }
  // hipblaslt_ext::GemmEpilogue supports broadcasting
  groupedgemm.setProblem(m, n, k, b, epilogue, inputs);
  const int request_solutions = 1;
  std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
  NVTE_CHECK_HIPBLASLT(groupedgemm.algoGetHeuristic(request_solutions, gemmPref, heuristicResult));
  if (heuristicResult.empty()) {
    std::cerr << "No valid solution found!" << std::endl;
    return;
  }
  // Make sure to initialize every time the algo changes
  NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace));
  // Get the default values from the groupedgemm object
  groupedgemm.getDefaultValueForDeviceUserArguments(userArgs);
  // Copy them to device memory
  // hipblaslt_ext::UserArguments* d_userArgs;
  // NVTE_CHECK_CUDA(hipMallocAsync(&d_userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments), stream));
  NVTE_CHECK_CUDA(hipMemcpy(d_userArgs, userArgs, m.size() * sizeof(hipblaslt_ext::UserArguments),
                            hipMemcpyHostToDevice));
  NVTE_CHECK_HIPBLASLT(groupedgemm.run(d_userArgs, stream));
  // NVTE_CHECK_HIPBLASLT(groupedgemm.initialize(heuristicResult[0].algo, workspace, false, stream));
  // NVTE_CHECK_HIPBLASLT(groupedgemm.run(stream));
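  // Call sequence used above (hipblaslt_ext grouped-GEMM API): setProblem() ->
  // algoGetHeuristic() -> initialize(algo, workspace) -> getDefaultValueForDeviceUserArguments()
  // into the pinned host buffer -> hipMemcpy to the device buffer -> run(d_userArgs, stream).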
@@ -1622,66 +1431,49 @@ void hipblaslt_goupedgemm(std::vector<const Tensor*>& inputA, std::vector<const
  // NVTE_CHECK_CUDA(hipFree(userArgs));
}
#endif //USE_HIPBLASLT
#ifdef USE_ROCBLAS  // Use rocblas + kernel, no fusion
inline void CreateRocblasHandle(rocblas_handle* handle) {
  NVTE_CHECK_ROCBLAS(rocblas_create_handle(handle));
}
using rocblasHandleManager = detail::HandleManager<rocblas_handle, CreateRocblasHandle>;
void rocblas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                  const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                  int ldb, int ldd, rocblas_operation transa, rocblas_operation transb, bool grad,
                  void* workspace, size_t workspaceSize, bool accumulate,
                  bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                  bool gemm_producer, const Tensor* inputCounter, hipStream_t stream) {
  void* A = inputA->data.dptr;
  void* A_scale_inverse = inputA->scale_inv.dptr;
  void* B = inputB->data.dptr;
  void* B_scale_inverse = inputB->scale_inv.dptr;
  void* C = outputD->data.dptr;
  void* D = outputD->data.dptr;
  void* D_scale = outputD->scale.dptr;
  void* D_amax = outputD->amax.dptr;
  void* bias_ptr = inputBias->data.dptr;
  const bool bias = bias_ptr != nullptr;
  void* pre_gelu_out = outputPreGelu->data.dptr;
  const bool gelu = pre_gelu_out != nullptr;
  const bool use_fp8 = is_fp8_dtype(inputA->data.dtype) || is_fp8_dtype(inputB->data.dtype);
  const rocblas_datatype A_type = get_rocblas_dtype(inputA->data.dtype);
  const rocblas_datatype B_type = get_rocblas_dtype(inputB->data.dtype);
  const rocblas_datatype D_type = get_rocblas_dtype(outputD->data.dtype);
  const rocblas_datatype bias_type = get_rocblas_dtype(inputBias->data.dtype);
  const rocblas_datatype gelu_type = get_rocblas_dtype(outputPreGelu->data.dtype);
  // check consistency of arguments:
  // if fp8 is desired, context cannot be null
  // fp8 + gelu fusion + fp8 aux is unavailable right now.
  if (use_fp8 && gelu) {
    NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype),
               "fp8 Aux output for gemm + gelu fusion not supported!");
  }
  if (is_fp8_dtype(outputD->data.dtype)) {
    NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!");
  }
  // fp8 + grad unavailable in upstream
  NVTE_CHECK(!(use_fp8 && grad), "fp8 + grad not supported!");
@@ -1692,120 +1484,136 @@ void rocblas_gemm(const Tensor *inputA,
  float alpha = 1.0;
  if (use_fp8) {
    float A_scale_inv, B_scale_inv;
    (void)hipMemcpy(&A_scale_inv, A_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
    (void)hipMemcpy(&B_scale_inv, B_scale_inverse, sizeof(float), hipMemcpyDeviceToHost);
    alpha = A_scale_inv * B_scale_inv;
  }
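  // With FP8 inputs the dequantization is folded into alpha:
  //   D = (A / s_A) * (B / s_B) = (s_A^-1 * s_B^-1) * (A * B), where scale_inv holds s^-1.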
  rocblas_handle handle = rocblasHandleManager::Instance().GetHandle();
  NVTE_CHECK_ROCBLAS(rocblas_set_stream(handle, stream));
  // extract the stream order alloc env
  bool stream_order_alloc = false;
  if (const char* env_p = std::getenv("ROCBLAS_STREAM_ORDER_ALLOC")) {
    // env_p is guaranteed non-null inside this branch
    if (std::string(env_p) == "1") stream_order_alloc = true;
  }
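  // e.g. ROCBLAS_STREAM_ORDER_ALLOC=1 switches the temporary-buffer management below to
  // hipMallocAsync/hipFreeAsync on the compute stream instead of blocking hipMalloc/hipFree.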
  int64_t ld_gelumat = (int64_t)ldd;
  NVTE_CHECK((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
              D_type == rocblas_datatype_f16_r) ||
             (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
              D_type == rocblas_datatype_f32_r) ||
             (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
              D_type == rocblas_datatype_bf16_r) ||
             (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
              D_type == rocblas_datatype_f32_r) ||
             (A_type == rocblas_datatype_f32_r && B_type == rocblas_datatype_f32_r &&
              D_type == rocblas_datatype_f32_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_f32_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_f16_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_bf16_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_f8_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_bf8_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
              D_type == rocblas_datatype_f32_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
              D_type == rocblas_datatype_f16_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
              D_type == rocblas_datatype_bf16_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
              D_type == rocblas_datatype_f8_r) ||
             (A_type == rocblas_datatype_f8_r && B_type == rocblas_datatype_bf8_r &&
              D_type == rocblas_datatype_bf8_r) ||
             (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_f32_r) ||
             (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_f16_r) ||
             (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_bf16_r) ||
             (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_f8_r) ||
             (A_type == rocblas_datatype_bf8_r && B_type == rocblas_datatype_f8_r &&
              D_type == rocblas_datatype_bf8_r),
             "Only the following combinations of data types are enabled now!\n\
    1. input: fp32, output: fp32.\n\
    2. input: fp16, output: fp16.\n\
    3. input: bf16, output: bf16.\n\
    4. input: fp8/bf8, output: fp8/bf8, fp16/bf16, fp32");
  //If D is not fp32, then we need a temp buffer for the GEMM result before applying epilogues. Otherwise, we can apply epilogues in-place.
  // with bias or gelu, allocate fp32 D_temp if the output is not fp32
  // with input fp8/bf8 (use_fp8) and bf16 output, need a fp32 D_temp, as rocblas does not support this case (fp8/bf8 input with fp16/fp32 output is supported)
  // with use_fp8 true and fp8/bf8 output, need fp32 D_temp to support the amax and scale operation
  void* D_temp;
  if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
      (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                   D_type == rocblas_datatype_bf8_r))) {
    if (!stream_order_alloc) {
      NVTE_CHECK_CUDA(hipMalloc(&D_temp, sizeof(float) * m * n));
    } else {
      NVTE_CHECK_CUDA(hipMallocAsync(&D_temp, sizeof(float) * m * n, stream));
    }
  } else {
    D_temp = D;
  }
  // When Ti=To=fp16 and there is no bias or gelu, D_temp points to D and we would like it to be fp16
  rocblas_datatype D_temp_type = rocblas_datatype_f32_r;
  if (!(bias || gelu) && (A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r &&
                          D_type == rocblas_datatype_f16_r)) {
    D_temp_type = rocblas_datatype_f16_r;
  }
  // When Ti=To=bf16 and there is no bias or gelu, D_temp points to D and we would like it to be bf16
  if (!(bias || gelu) && (A_type == rocblas_datatype_bf16_r && B_type == rocblas_datatype_bf16_r &&
                          D_type == rocblas_datatype_bf16_r)) {
    D_temp_type = rocblas_datatype_bf16_r;
  }
  // When Ti is fp8 or bf8, To=fp16, and there is no bias or gelu, D_temp points to D and we would like it to be fp16, as rocblas supports this case.
  if ((!(bias || gelu)) && (use_fp8 && D_type == rocblas_datatype_f16_r)) {
    D_temp_type = rocblas_datatype_f16_r;
  }
  if (accumulate && (D_temp != D || D_temp_type != D_type)) {
    DType output_dtype = get_transformer_engine_dtype(D_type);
    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
        output_dtype, OType,
        //D_temp allocated only with fp32
        detail::identity_kernelLauncher<OType, float>(
            reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(D_temp), m * n, stream););
  }
  // D = alpha * (A * B) + beta * C
  if (use_fp8) {
    rocblas_computetype computeType = rocblas_compute_type_f32;
    NVTE_CHECK_ROCBLAS(rocblas_gemm_ex3(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
                                        B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
                                        D_temp_type, ldd, computeType,
                                        rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, 0));
  } else {
    rocblas_datatype computeType = rocblas_datatype_f32_r;
    uint32_t flags = rocblas_gemm_flags_none;
    if ((A_type == rocblas_datatype_f16_r && B_type == rocblas_datatype_f16_r) && grad) {
      flags = rocblas_gemm_flags_fp16_alt_impl;
    }
    NVTE_CHECK_ROCBLAS(rocblas_gemm_ex(handle, transa, transb, m, n, k, &alpha, A, A_type, lda, B,
                                       B_type, ldb, &beta, D_temp, D_temp_type, ldd, D_temp,
                                       D_temp_type, ldd, computeType,
                                       rocblas_gemm_algo::rocblas_gemm_algo_standard, 0, flags));
  }
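  // Both branches accumulate in fp32 (rocblas_compute_type_f32 / rocblas_datatype_f32_r).
  // rocblas_gemm_flags_fp16_alt_impl requests rocBLAS's alternate fp16 implementation
  // (to our understanding, a variant with backward-pass-friendly rounding on gfx90a),
  // which is why it is only set when grad is true.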
  int batch_size, input_dim, output_dim;
  if (bias && gelu) {
    if (grad) {
      // epilogue = CUBLASLT_EPILOGUE_DGELU_BGRAD;
      // Apply GELU gradient to D_temp and store in D
      // Apply bias gradient to D (D is already the result of GELU gradient) and store in bias_ptr;
      // This case is NN
      // D_temp is of shape (m, n) in column major and thus is of shape (n, m) in row major
      // The bias vector length is m. So it will be reduced along axis 0 in row major
@@ -1815,57 +1623,53 @@ void rocblas_gemm(const Tensor *inputA,
      // the GELU gradient result in lower precision (D). It might be better to take the GELU
      // gradient result in fp32, but as it requires some kernel changes I would only do that
      // once we confirm that this is the right form of the epilogue.
      // This is for linear1 -> gelu -> linear2
      // compute dX = dY * W for linear2
      // gemm_ex(A=W, B=dY)
      batch_size = n;
      input_dim =
          m;  // input dimension of the second linear layer is the output dimension of the first linear layer
      output_dim = k;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              detail::gelu_backward_kernelLauncher<OType, GType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
      void* bias_tmp;
      if (bias_type != rocblas_datatype_f32_r) {
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipMalloc(
              &bias_tmp,
              sizeof(float) * input_dim));  // The bias gradient is for the first linear layer
        } else {
          NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * input_dim, stream));
        }
      } else {
        bias_tmp = bias_ptr;
      }
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::bias_gradient_kernelLauncher<OType>(
              reinterpret_cast<const OType*>(D), reinterpret_cast<float*>(bias_tmp), batch_size,
              input_dim, stream_order_alloc, stream););
      if (bias_type != rocblas_datatype_f32_r) {
        DType bias_dtype = get_transformer_engine_dtype(bias_type);
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            bias_dtype, BType,
            detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
                                                          reinterpret_cast<BType*>(bias_ptr),
                                                          input_dim, stream););
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipFree(bias_tmp));
        } else {
          NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
        }
      }
@@ -1880,23 +1684,20 @@ void rocblas_gemm(const Tensor *inputA,
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
                  bias_dtype, BType,
                  detail::add_bias_gelu_kernelLauncher<OType, GType, BType>(
                      reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                      reinterpret_cast<GType*>(pre_gelu_out),
                      reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
                      reinterpret_cast<const float*>(D_scale), batch_size, output_dim,
                      stream););););
    }
  } else if (bias) {
    if (grad) {
      // grad output is always input B
      // epilogue = CUBLASLT_EPILOGUE_BGRADB;
@@ -1908,52 +1709,47 @@ void rocblas_gemm(const Tensor *inputA,
      batch_size = k;
      input_dim = m;
      output_dim = n;
      void* bias_tmp;
      if (bias_type != rocblas_datatype_f32_r) {
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipMalloc(&bias_tmp, sizeof(float) * output_dim));
        } else {
          NVTE_CHECK_CUDA(hipMallocAsync(&bias_tmp, sizeof(float) * output_dim, stream));
        }
      } else {
        bias_tmp = bias_ptr;
      }
      DType input_dtype = get_transformer_engine_dtype(B_type);
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          input_dtype, IType,
          detail::bias_gradient_kernelLauncher<IType>(
              reinterpret_cast<const IType*>(B), reinterpret_cast<float*>(bias_tmp), batch_size,
              output_dim, stream_order_alloc, stream););
      if (bias_type != rocblas_datatype_f32_r) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            bias_dtype, BType,
            detail::identity_kernelLauncher<float, BType>(reinterpret_cast<const float*>(bias_tmp),
                                                          reinterpret_cast<BType*>(bias_ptr),
                                                          output_dim, stream););
        if (!stream_order_alloc) {
          NVTE_CHECK_CUDA(hipFree(bias_tmp));
        } else {
          NVTE_CHECK_CUDA(hipFreeAsync(bias_tmp, stream));
        }
      }
      if (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r) {
        TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
            output_dtype, OType,
            detail::identity_kernelLauncher<float, OType>(reinterpret_cast<const float*>(D_temp),
                                                          reinterpret_cast<OType*>(D),
                                                          input_dim * output_dim, stream););
      }
    } else {
      // epilogue = CUBLASLT_EPILOGUE_BIAS;
      // Broadcast bias and add it to D_temp and store in D. The bias vector length is m
      // D_temp is of shape (m, n) in column major and thus is of shape (n, m) in row major
      // gemm_ex(A=W, B=X, transA=T)
      batch_size = n;
@@ -1961,40 +1757,33 @@ void rocblas_gemm(const Tensor *inputA,
      output_dim = m;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType bias_dtype = get_transformer_engine_dtype(bias_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              bias_dtype, BType,
              detail::add_bias_kernelLauncher<OType, BType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const BType*>(bias_ptr), reinterpret_cast<float*>(D_amax),
                  reinterpret_cast<const float*>(D_scale), batch_size, output_dim, stream);););
    }
  } else if (gelu) {
    if (grad) {
      // epilogue = CUBLASLT_EPILOGUE_DGELU;
      // Take input from pre_gelu_out and apply GELU gradients to D_temp, then store the result in D
      // D_temp is of shape (m, n) in column major and thus is of shape (n, m) in row major
      // gemm_ex(A=W, B=dY)
      batch_size = n;
      input_dim = m;
      output_dim = k;
      DType output_dtype = get_transformer_engine_dtype(D_type);
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
              gelu_dtype, GType,
              detail::gelu_backward_kernelLauncher<OType, GType>(
                  reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
                  reinterpret_cast<const GType*>(pre_gelu_out), batch_size, input_dim, stream);););
    } else {
      // epilogue = CUBLASLT_EPILOGUE_GELU_AUX;
      // Store (quantized) D_temp in pre_gelu_out, and apply GELU to D_temp then store in D
...@@ -2005,57 +1794,53 @@ void rocblas_gemm(const Tensor *inputA, ...@@ -2005,57 +1794,53 @@ void rocblas_gemm(const Tensor *inputA,
      output_dim = m;
      DType gelu_dtype = get_transformer_engine_dtype(gelu_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          gelu_dtype, GType,
          detail::identity_kernelLauncher<float, GType>(reinterpret_cast<const float*>(D_temp),
                                                        reinterpret_cast<GType*>(pre_gelu_out),
                                                        batch_size * output_dim, stream););
      DType output_dtype = get_transformer_engine_dtype(D_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::gelu_forward_kernelLauncher<OType>(
              reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
              reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), batch_size,
              output_dim, stream););
    }
  } else {  // No epilogue - !(bias || gelu)
    if (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                    D_type == rocblas_datatype_bf8_r)) {
      DType output_dtype = get_transformer_engine_dtype(D_type);
      TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
          output_dtype, OType,
          detail::identity_output_kernelLauncher<OType>(
              reinterpret_cast<const float*>(D_temp), reinterpret_cast<OType*>(D),
              reinterpret_cast<float*>(D_amax), reinterpret_cast<const float*>(D_scale), m * n,
              stream););
    }
  }
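  // D_temp is only allocated as an intermediate fp32 buffer on the paths above
  // (bias/gelu epilogues with fp16/bf16 output, or fp8 GEMMs with bf16/fp8/bf8 output),
  // so it is freed under the same condition; hipFreeAsync keeps the free ordered on
  // `stream` when stream-ordered allocation is enabled.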
  if (((bias || gelu) && (D_type == rocblas_datatype_f16_r || D_type == rocblas_datatype_bf16_r)) ||
      (use_fp8 && (D_type == rocblas_datatype_bf16_r || D_type == rocblas_datatype_f8_r ||
                   D_type == rocblas_datatype_bf8_r))) {
    if (!stream_order_alloc) {
      NVTE_CHECK_CUDA(hipFree(D_temp));
    } else {
      NVTE_CHECK_CUDA(hipFreeAsync(D_temp, stream));
    }
  }
}
#endif  // USE_ROCBLAS
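
// cublas_gemm keeps the upstream CUDA entry-point name but is a thin dispatcher on
// ROCm: it routes each call to hipblaslt_gemm or rocblas_gemm based on the compiled-in
// backends and the NVTE_USE_* requests.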
void cublas_gemm(const Tensor* inputA, const Tensor* inputB, Tensor* outputD,
                 const Tensor* inputBias, Tensor* outputPreGelu, int m, int n, int k, int lda,
                 int ldb, int ldd, bool transa, bool transb, bool grad, void* workspace,
                 size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                 int math_sm_count, int m_split, int n_split, bool gemm_producer,
                 const Tensor* inputCounter, hipStream_t stream, bool nvte_use_hipblaslt = 0,
                 bool nvte_use_rocblas = 0, int compute_stream_offset = -1) {
  /* If no backend is specified via an env variable, use HIPBLASLT unless it is disabled.
     If the HIPBLASLT backend is enabled and requested, use it regardless of ROCBLAS status.
     Otherwise use ROCBLAS. */
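  // A minimal sketch of how the two request flags might be derived (assumed shape only;
  // the actual env parsing lives in the elided lines below and may differ):
  //
  //   bool use_hipblaslt = nvte_use_hipblaslt || getenv("NVTE_USE_HIPBLASLT") != nullptr;
  //   bool use_rocblas   = nvte_use_rocblas   || getenv("NVTE_USE_ROCBLAS") != nullptr;
  //
  // The preprocessor ladder below then reconciles the requested backend with the
  // backends that were actually compiled in.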
@@ -2066,27 +1851,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#if !defined(USE_HIPBLASLT) && !defined(USE_ROCBLAS)
#error GEMM backend is not specified
#elif !defined(USE_HIPBLASLT)
  if (use_hipblaslt) {
    use_hipblaslt = false;
    use_rocblas = true;
    std::cout << "[NOTICE] hipBLASLt is not enabled, NVTE_USE_HIPBLASLT env is ignored\n";
  }
#elif !defined(USE_ROCBLAS)
  if (use_rocblas) {
    use_rocblas = false;
    use_hipblaslt = true;
    std::cout << "[NOTICE] rocBLAS is not enabled, NVTE_USE_ROCBLAS env is ignored\n";
  }
#else
  if (use_hipblaslt && use_rocblas) {
    use_rocblas = false;
    use_hipblaslt = true;
    // std::cout << "[NOTICE] Both GEMM backends are enabled, hipBLASLt will be used\n";
  } else if (!use_hipblaslt && !use_rocblas) {
    use_rocblas = false;
    use_hipblaslt = true;
    // std::cout << "[NOTICE] Both GEMM backends are disabled, hipBLASLt will be used\n";
@@ -2094,8 +1875,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif
#ifdef USE_HIPBLASLT
  if (use_hipblaslt || !use_rocblas) {
    // Check that compute_stream_offset is in range.
    NVTE_CHECK(compute_stream_offset >= -1 && compute_stream_offset < compute_num_streams);
@@ -2109,33 +1889,24 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
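      // A non-negative compute_stream_offset selects the matching per-stream hipBLASLt
      // handle from hipblaslt_handles; -1 (allowed by the check above) keeps the default.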
      handle = hipblaslt_handles[compute_stream_offset];
    }
    hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                   (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                   grad, workspace, workspaceSize, accumulate, use_split_accumulator,
                   math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream, handle);
    return;
  }
#endif
#ifdef USE_ROCBLAS
  if (use_rocblas) {
    rocblas_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd,
                 (transa) ? rocblas_operation_transpose : rocblas_operation_none,
                 (transb) ? rocblas_operation_transpose : rocblas_operation_none, grad, workspace,
                 workspaceSize, accumulate, use_split_accumulator, math_sm_count, m_split, n_split,
                 gemm_producer, inputCounter, stream);
  }
#endif
}
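
// Illustrative call shape for the dispatcher above (hypothetical values, not part of
// this file): a plain fp16 GEMM D = A * B with no bias/GELU epilogue could look
// roughly like
//
//   cublas_gemm(A, B, D, /*inputBias=*/nullptr, /*outputPreGelu=*/nullptr,
//               m, n, k, /*lda=*/m, /*ldb=*/k, /*ldd=*/m,
//               /*transa=*/false, /*transb=*/false, /*grad=*/false,
//               workspace, workspaceSize, /*accumulate=*/false,
//               /*use_split_accumulator=*/false, /*math_sm_count=*/0,
//               /*m_split=*/0, /*n_split=*/0, /*gemm_producer=*/false,
//               /*inputCounter=*/nullptr, stream);
//
// leaving backend selection to the trailing defaults (hipBLASLt unless overridden via
// NVTE_USE_HIPBLASLT / NVTE_USE_ROCBLAS).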
}  // namespace transformer_engine
\ No newline at end of file