"vscode:/vscode.git/clone" did not exist on "5b6acc1495f4c4d44bfdb0ce8090426de280b002"
Commit 899b52ce authored by aiss's avatar aiss
Browse files

Merge branch 'ds-v0.8.2-rocm' into 'main'

Ds v0.8.2 rocm, support torch1.13 for the hipify change

See merge request aicomponent/deepspeed!1
parents 67ea635f 4acf0e01
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "custom_hip_layers.h"
#include <fstream>
using namespace std;
template <typename T>
class Softmax {
public:
struct Config {
size_t batchSize;
size_t heads;
size_t seq_length;
size_t prob_depth;
float temperature;
bool mem_alloc;
Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false)
: batchSize(batch),
heads(h),
seq_length(seq),
prob_depth(prob_size),
temperature(1.0),
mem_alloc(mem_alloc)
{
}
};
Softmax(Config config) : config_(config) {}
~Softmax() {}
void Forward(int bsz, T* vals, const T* attn_mask, hipStream_t& stream)
{
launch_attn_softmax<T>(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream);
}
void Backward(int bsz, T* out_grad, const T* soft_out, hipStream_t stream)
{
launch_attn_softmax_backward_v2<T>(
out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream);
}
inline size_t GetProbDepth() const { return config_.prob_depth; }
inline size_t GetBatchSize() const { return config_.batchSize; }
inline size_t GetNumHeads() const { return config_.heads; }
inline size_t GetSeqLength() const { return config_.seq_length; }
inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; }
private:
Config config_;
};
// !!! This is a file automatically generated by hipify!!!
#pragma once
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <stdio.h>
#include "context_hip.h"
template <typename T>
class StridedBatchGemm {
public:
struct Config {
int batch_size;
int m;
int n;
int k;
float alpha;
float beta;
rocblas_operation op_A;
rocblas_operation op_B;
std::array<int, 3> gemm_algos;
Config(int batch,
int mm,
int nn,
int kk,
float param_alpha,
float param_beta,
rocblas_operation opA,
rocblas_operation opB,
const std::array<int, 3>& algos)
: batch_size(batch),
m(mm),
n(nn),
k(kk),
alpha(param_alpha),
beta(param_beta),
op_A(opA),
op_B(opB),
gemm_algos(algos)
{
}
void SetConfig(int mm, int nn, int kk)
{
m = mm;
n = nn;
k = kk;
}
};
StridedBatchGemm(const Config& config) : _config(config) {}
virtual ~StridedBatchGemm() {}
void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
}
void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, rocblas_handle handle)
{
int stride_a = _config.m * _config.k;
int stride_b = _config.n * _config.k;
int stride_c = _config.m * _config.n;
cublas_strided_batched_gemm(handle,
_config.m,
_config.n,
_config.k,
&_config.alpha,
&_config.beta,
_buffer_a,
_buffer_b,
output,
_config.op_A,
_config.op_B,
stride_a,
stride_b,
stride_c,
_config.batch_size,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
cublasGemmAlgo_t(_config.gemm_algos[0]));
#endif
k_buf = _buffer_a;
q_buf = _buffer_b;
}
void Backward(int bsz,
const T* d_output,
const T* _buffer_a,
const T* _buffer_b,
rocblas_handle handle,
T* inpGradA = nullptr,
T* inpGradB = nullptr)
{
int mb = (_config.op_A == rocblas_operation_transpose ? _config.k : _config.m);
int kb = (_config.op_A == rocblas_operation_transpose ? _config.m : _config.k);
int stride_a = mb * _config.n;
int stride_b = _config.n * kb;
int stride_c = _config.m * _config.k;
// B need to transpose.
rocblas_operation op_b = (_config.op_B == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
// Calculate d_A.
cublas_strided_batched_gemm(handle,
mb,
kb,
_config.n,
&_config.alpha,
&_config.beta,
(_config.op_A == rocblas_operation_transpose ? _buffer_b : d_output),
(_config.op_A == rocblas_operation_transpose ? d_output : _buffer_b),
inpGradA,
rocblas_operation_none,
op_b,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[1]));
#else
cublasGemmAlgo_t(_config.gemm_algos[1]));
#endif
// A need to transpose.
rocblas_operation op_a = (_config.op_A == rocblas_operation_transpose ? rocblas_operation_none : rocblas_operation_transpose);
stride_a = _config.m * _config.k;
stride_b = _config.m * _config.n;
stride_c = _config.n * _config.k;
// Calculate d_B.
cublas_strided_batched_gemm(handle,
_config.k,
_config.n,
_config.m,
&_config.alpha,
&_config.beta,
_buffer_a,
d_output,
inpGradB,
op_a,
rocblas_operation_none,
stride_a,
stride_b,
stride_c,
bsz,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo(_config.gemm_algos[2]));
#else
cublasGemmAlgo_t(_config.gemm_algos[2]));
#endif
}
inline int GetN() const { return _config.k; }
inline const T* GetBufferA() const { return k_buf; }
inline const T* GetBufferB() const { return q_buf; }
inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); }
private:
Config _config;
const T* q_buf;
const T* k_buf;
};
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/* Taken from NVIDIA/apex commit 855808f3fc268e9715d613f3c2e56469d8c986d8 */
#include <ATen/ATen.h>
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template <typename T>
__device__ __forceinline__ T
reduce_block_into_lanes(T* x,
T val,
int lanes = 1,
bool share_result = false) // lanes is intended to be <= 32.
{
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
if (blockSize >= 64) {
x[tid] = val;
__syncthreads();
}
#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
if (tid < i) x[tid] = x[tid] + x[tid + i];
__syncthreads();
}
T final;
if (tid < 32) {
if (blockSize >= 64)
final = x[tid] + x[tid + 32];
else
final = val;
// __SYNCWARP();
#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}
if (share_result) {
if (tid < lanes) x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}
return final;
}
// !!! This is a file automatically generated by hipify!!!
/* Copyright 2019 The Microsoft DeepSpeed Team */
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <stdio.h>
#include <cmath>
#include "ATen/ATen.h"
#include "ATen/TensorUtils.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/detail/IndexUtils.cuh"
//#include "ATen/Type.h"
#include "ATen/AccumulateType.h"
#include <iostream>
//#include <helper_functions.h>
#if defined(__HIP_PLATFORM_HCC__) && HIP_VERSION > 305
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif
#include <hip/hip_runtime_api.h>
#include <stdio.h>
namespace cg = cooperative_groups;
// Utility class used to avoid linker errors with extern
// unsized shared memory arrays with templated type
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
template <typename T>
struct SharedMemory {
// Ensure that we won't compile any un-specialized types
__device__ inline operator T*()
{
#ifndef _WIN32
extern __device__ void error(void);
error();
#endif
return NULL;
}
};
template <>
struct SharedMemory<float> {
__device__ inline operator float*()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
template <>
struct SharedMemory<double> {
__device__ inline operator double*()
{
HIP_DYNAMIC_SHARED( double, s_double)
return s_double;
}
};
} // namespace
#include "type_shim_hip.h"
typedef enum {
ADAM_MODE_0 = 0, // eps under square root
ADAM_MODE_1 = 1 // eps outside square root
} adamMode_t;
// s_a and s_b are in shared memory
// g_a and g_b are in shared memory
template <typename T, int blockSize>
__device__ void reduce_block_in_shared_memory(T* s_a, T* s_b, T* g_a, T* g_b)
{
// Handle to thread block group
cg::thread_block cta = cg::this_thread_block();
// perform block reduction in shared memory,
unsigned int tid = cta.thread_rank();
T a_sum = s_a[tid];
T b_sum = s_b[tid];
cg::sync(cta);
// do reduction in shared mem
if ((blockSize >= 512) && (tid < 256)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 256];
s_b[tid] = b_sum = b_sum + s_b[tid + 256];
}
cg::sync(cta);
if ((blockSize >= 256) && (tid < 128)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 128];
s_b[tid] = b_sum = b_sum + s_b[tid + 128];
}
cg::sync(cta);
if ((blockSize >= 128) && (tid < 64)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 64];
s_b[tid] = b_sum = b_sum + s_b[tid + 64];
}
cg::sync(cta);
#if (__CUDA_ARCH__ >= 300)
if (tid < 32) {
cg::coalesced_group active = cg::coalesced_threads();
// Fetch final intermediate sum from 2nd warp
if (blockSize >= 64) {
a_sum = a_sum + s_a[tid + 32];
b_sum = b_sum + s_b[tid + 32];
}
// Reduce final warp using shuffle
for (int offset = warpSize / 2; offset > 0; offset /= 2) {
a_sum += active.shfl_down(a_sum, offset);
b_sum += active.shfl_down(b_sum, offset);
}
}
#else
if ((blockSize >= 64) && (tid < 32)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 32];
s_b[tid] = b_sum = b_sum + s_b[tid + 32];
}
cg::sync(cta);
if ((blockSize >= 32) && (tid < 16)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 16];
s_b[tid] = b_sum = b_sum + s_b[tid + 16];
}
cg::sync(cta);
if ((blockSize >= 16) && (tid < 8)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 8];
s_b[tid] = b_sum = b_sum + s_b[tid + 8];
}
cg::sync(cta);
if ((blockSize >= 8) && (tid < 4)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 4];
s_b[tid] = b_sum = b_sum + s_b[tid + 4];
}
cg::sync(cta);
if ((blockSize >= 4) && (tid < 2)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 2];
s_b[tid] = b_sum = b_sum + s_b[tid + 2];
}
cg::sync(cta);
if ((blockSize >= 2) && (tid < 1)) {
s_a[tid] = a_sum = a_sum + s_a[tid + 1];
s_b[tid] = b_sum = b_sum + s_b[tid + 1];
}
cg::sync(cta);
#endif
// write result for this block to global mem
if (tid == 0) {
g_a[blockIdx.x] = (T)a_sum;
g_b[blockIdx.x] = (T)b_sum;
}
}
template <typename T, int blockSize>
__device__ void reduce_two_vectors_in_register(T a, T b, T* g_a, T* g_b)
{
const int threadIdInBlock = cg::this_thread_block().thread_rank();
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
s_a[threadIdInBlock] = a;
s_b[threadIdInBlock] = b;
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part1(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = 0;
T reg_u = 0;
for (int j = i; j < tsize; j += totThreads) {
T scaled_grad = g[j] / grad_scale;
T pj = p[j];
m[j] = b1 * m[j] + (1 - b1) * scaled_grad;
v[j] = b2 * v[j] + (1 - b2) * scaled_grad * scaled_grad;
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(v[j] + eps);
else // Mode 1
denom = sqrtf(v[j]) + eps;
T update = (m[j] / denom) + (decay * p[j]);
reg_u += update * update;
reg_w += pj * pj;
}
reduce_two_vectors_in_register<T, blockSize>(reg_w, reg_u, w_l2_i, u_l2_i);
}
template <typename T, typename GRAD_T, int blockSize>
__global__ void lamb_cuda_kernel_part2(const size_t tsize, T* __restrict__ g_a, T* __restrict__ g_b)
{
T* s_a = SharedMemory<T>();
T* s_b = SharedMemory<T>() + cg::this_thread_block().size();
const int threadIdInBlock = cg::this_thread_block().thread_rank();
s_a[threadIdInBlock] = g_a[threadIdInBlock];
s_b[threadIdInBlock] = g_b[threadIdInBlock];
if (threadIdInBlock >= tsize) {
s_a[threadIdInBlock] = 0.0;
s_b[threadIdInBlock] = 0.0;
}
reduce_block_in_shared_memory<T, blockSize>(s_a, s_b, g_a, g_b);
}
template <typename T, typename GRAD_T>
__global__ void lamb_cuda_kernel_part3(
T* __restrict__ p,
GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T* __restrict__ g,
const float b1,
const float b2,
const float max_coeff,
const float min_coeff,
const float eps,
const float grad_scale,
const float step_size,
const size_t tsize,
adamMode_t mode,
const float decay,
T* __restrict__ w_l2_i,
T* __restrict__ u_l2_i,
T* __restrict__ lamb_coeff_val)
{
// Assuming 2D grids and 2D blocks
const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
const int threadsPerBlock = blockDim.x * blockDim.y;
const int threadIdInBlock = cg::this_thread_block().thread_rank();
const int i = (blockId * threadsPerBlock + threadIdInBlock);
const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
T reg_w = sqrtf(w_l2_i[0]);
T reg_u = sqrtf(u_l2_i[0]);
float lamb_coeff = 1.0;
if (reg_w != 0 && reg_u != 0) {
lamb_coeff = reg_w / reg_u;
if (lamb_coeff > max_coeff) { lamb_coeff = max_coeff; }
if (lamb_coeff < min_coeff) { lamb_coeff = min_coeff; }
}
if (blockId == 0 && threadIdInBlock == 0) {
lamb_coeff_val[0] = lamb_coeff;
// printf("Cuda Lamb Coeff is %.6f \n",lamb_coeff);
}
for (int j = i; j < tsize; j += totThreads) {
T pj = (float)p[j];
T mj = m[j];
T vj = v[j];
float denom;
if (mode == ADAM_MODE_0)
denom = sqrtf(vj + eps);
else // Mode 1
denom = sqrtf(vj) + eps;
T update = (mj / denom) + (decay * pj);
pj = pj - (step_size * lamb_coeff * update);
p[j] = pj;
if (p_copy != NULL) p_copy[j] = (GRAD_T)pj;
}
}
void fused_lamb_cuda(at::Tensor& p,
at::Tensor& p_copy,
at::Tensor& m,
at::Tensor& v,
at::Tensor& g,
float lr,
float beta1,
float beta2,
float max_coeff,
float min_coeff,
float eps,
float grad_scale,
int step,
int mode,
int bias_correction,
float decay,
at::Tensor& w_l2_i,
at::Tensor& u_l2_i,
at::Tensor& lamb_coeff)
{
// using namespace at;
// Get tensor size
int tsize = p.numel();
// Determine #threads and #blocks
const int threadsPerBlock = 512;
int num_blocks = (tsize + threadsPerBlock - 1) / threadsPerBlock;
if (num_blocks > 512) num_blocks = 512;
int smemsize = 0;
if (p.type().scalarType() == at::ScalarType::Double)
smemsize = 2 * threadsPerBlock * sizeof(double);
else
smemsize = 2 * threadsPerBlock * sizeof(float);
const dim3 blocks(num_blocks);
const dim3 threads(threadsPerBlock);
AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p),
"parameter tensor is too large to be indexed with int32");
// Constants
float step_size = 0;
if (bias_correction == 1) {
const float bias_correction1 = 1 - ::pow(beta1, step);
const float bias_correction2 = 1 - ::pow(beta2, step);
step_size = lr * std::sqrt(bias_correction2) / bias_correction1;
} else {
step_size = lr;
}
hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
if (g.type().scalarType() == at::ScalarType::Half) {
// all other values should be fp32 for half gradients
AT_ASSERTM(p.type().scalarType() == at::ScalarType::Float,
"expected parameter to be of float type");
// dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
using accscalar_t = at::acc_type<scalar_t, true>;
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<accscalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<accscalar_t>(), u_l2_i.data<accscalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<accscalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<accscalar_t>(),
p_copy.numel() ? p_copy.data<scalar_t>() : NULL,
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<accscalar_t>(),
u_l2_i.data<accscalar_t>(),
lamb_coeff.data<accscalar_t>());
}));
} else {
using namespace at;
AT_DISPATCH_FLOATING_TYPES(
g.scalar_type(), "lamb_cuda_kernel", ([&] {
hipLaunchKernelGGL(( lamb_cuda_kernel_part1<scalar_t, scalar_t, threadsPerBlock>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part2<scalar_t, scalar_t, threadsPerBlock>)
, dim3(1), dim3(threadsPerBlock), smemsize, stream,
num_blocks, w_l2_i.data<scalar_t>(), u_l2_i.data<scalar_t>());
hipLaunchKernelGGL(( lamb_cuda_kernel_part3<scalar_t, scalar_t>)
, dim3(blocks), dim3(threadsPerBlock), smemsize, stream,
p.data<scalar_t>(),
NULL, // don't output p_copy for fp32, it's wasted write
m.data<scalar_t>(),
v.data<scalar_t>(),
g.data<scalar_t>(),
beta1,
beta2,
max_coeff,
min_coeff,
eps,
grad_scale,
step_size,
tsize,
(adamMode_t)mode,
decay,
w_l2_i.data<scalar_t>(),
u_l2_i.data<scalar_t>(),
lamb_coeff.data<scalar_t>());
}));
}
C10_HIP_CHECK(hipGetLastError());
}
// template __device__ void reduce_two_vectors_in_register<float,512>(float a, float b, float* g_a,
// float* g_b, cg::grid_group &cgg);
// !!! This is a file automatically generated by hipify!!!
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_hip_layers.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
template <typename T>
at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
template <typename T>
at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
template <typename T>
at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
{
auto t_size = vals.sizes();
int size = 1;
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return vals;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_quantize_fp16", &ds_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_sr_quantize_fp32", &ds_sr_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_sr_quantize_fp16", &ds_sr_quantize<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_quantize_asym_fp32", &ds_quantize_asym<float>, "DeepSpeed Quantize with fp32 (CUDA)");
m.def(
"ds_quantize_asym_fp16", &ds_quantize_asym<__half>, "DeepSpeed Quantize with fp16 (CUDA)");
m.def("ds_sr_quantize_asym_fp32",
&ds_sr_quantize_asym<float>,
"DeepSpeed Quantize with fp32 (CUDA)");
m.def("ds_sr_quantize_asym_fp16",
&ds_sr_quantize_asym<__half>,
"DeepSpeed Quantize with fp16 (CUDA)");
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <math.h>
#include "custom_hip_layers.h"
namespace cg = cooperative_groups;
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (abs((float)data_h[0]) > max) max = abs((float)data_h[0]);
if (abs((float)data_h[1]) > max) max = abs((float)data_h[1]);
if (abs((float)data_h[2]) > max) max = abs((float)data_h[2]);
if (abs((float)data_h[3]) > max) max = abs((float)data_h[3]);
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf(q_data[0].x * q_scale);
q_data_int[0].y = roundf(q_data[0].y * q_scale);
q_data_int[1].x = roundf(q_data[1].x * q_scale);
q_data_int[1].y = roundf(q_data[1].y * q_scale);
q_data_int[0].x *= q_scale_inv;
q_data_int[0].y *= q_scale_inv;
q_data_int[1].x *= q_scale_inv;
q_data_int[1].y *= q_scale_inv;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (abs(data_reg.x) > max) max = abs(data_reg.x);
if (abs(data_reg.y) > max) max = abs(data_reg.y);
if (abs(data_reg.z) > max) max = abs(data_reg.z);
if (abs(data_reg.w) > max) max = abs(data_reg.w);
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
b.sync();
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale = (1 << num_bits) / (2 * max + 1e-5);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf(q_data.x * q_scale);
q_data_int.y = roundf(q_data.y * q_scale);
q_data_int.w = roundf(q_data.w * q_scale);
q_data_int.z = roundf(q_data.z * q_scale);
q_data.x = q_data_int.x * q_scale_inv;
q_data.y = q_data_int.y * q_scale_inv;
q_data.w = q_data_int.w * q_scale_inv;
q_data.z = q_data_int.z * q_scale_inv;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void sr_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (abs((float)data_f[0].x) > max) max = abs((float)data_f[0].x);
if (abs((float)data_f[0].y) > max) max = abs((float)data_f[0].y);
if (abs((float)data_f[1].x) > max) max = abs((float)data_f[1].x);
if (abs((float)data_f[1].y) > max) max = abs((float)data_f[1].y);
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((int)(data_f[0].x * q_scale_val));
q_data_int[0].y = (float)((int)(data_f[0].y * q_scale_val));
q_data_int[1].x = (float)((int)(data_f[1].x * q_scale_val));
q_data_int[1].y = (float)((int)(data_f[1].y * q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(data_f[0].x - (q_data_int[0].x / q_scale_val)) * q_scale_val;
q_error[1] = abs(data_f[0].y - (q_data_int[0].y / q_scale_val)) * q_scale_val;
q_error[2] = abs(data_f[1].x - (q_data_int[1].x / q_scale_val)) * q_scale_val;
q_error[3] = abs(data_f[1].y - (q_data_int[1].y / q_scale_val)) * q_scale_val;
q_data_int[0].x =
(rand.x < q_error[0] && q_data_int[0].x > low_q && q_data_int[0].x < high_q)
? (q_data_int[0].x + (data_f[0].x > 0 ? 1 : -1))
: q_data_int[0].x;
q_data_int[0].y =
(rand.y < q_error[1] && q_data_int[0].y > low_q && q_data_int[0].y < high_q)
? (q_data_int[0].y + (data_f[0].y > 0 ? 1 : -1))
: q_data_int[0].y;
q_data_int[1].x =
(rand.w < q_error[2] && q_data_int[1].x > low_q && q_data_int[1].x < high_q)
? (q_data_int[1].x + (data_f[1].x > 0 ? 1 : -1))
: q_data_int[1].x;
q_data_int[1].y =
(rand.z < q_error[3] && q_data_int[1].y > low_q && q_data_int[1].y < high_q)
? (q_data_int[1].y + (data_f[1].y > 0 ? 1 : -1))
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x / q_scale_val;
data_f[0].y = q_data_int[0].y / q_scale_val;
data_f[1].x = q_data_int[1].x / q_scale_val;
data_f[1].y = q_data_int[1].y / q_scale_val;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
// float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
data[reg_count] = vals_cast[group_index];
if (abs(data[reg_count].x) > max) max = abs(data[reg_count].x);
if (abs(data[reg_count].y) > max) max = abs(data[reg_count].y);
if (abs(data[reg_count].z) > max) max = abs(data[reg_count].z);
if (abs(data[reg_count].w) > max) max = abs(data[reg_count].w);
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
__shared__ float partialMax[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
b.sync();
if (lane < warp_num) max = partialMax[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
max = g.shfl(max, 0);
float q_scale_val = (float)(1 << num_bits) / (max * 2 + 1e-5);
float high_q = (float)((1 << (num_bits - 1)) - 1);
float low_q = (float)(-((1 << (num_bits - 1))));
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)(q_data.x * q_scale_val));
q_data_int.y = (float)((int)(q_data.y * q_scale_val));
q_data_int.w = (float)((int)(q_data.w * q_scale_val));
q_data_int.z = (float)((int)(q_data.z * q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - (q_data_int.x / q_scale_val)) * q_scale_val;
q_error[1] = abs(q_data.y - (q_data_int.y / q_scale_val)) * q_scale_val;
q_error[2] = abs(q_data.w - (q_data_int.w / q_scale_val)) * q_scale_val;
q_error[3] = abs(q_data.z - (q_data_int.z / q_scale_val)) * q_scale_val;
q_data_int.x =
(rand.x < q_error[0] && q_data_int.x > low_q && q_data_int.x < high_q)
? (q_data_int.x + (q_data.x > 0 ? 1 : -1))
: q_data_int.x;
q_data_int.y =
(rand.y < q_error[1] && q_data_int.y > low_q && q_data_int.y < high_q)
? (q_data_int.y + (q_data.y > 0 ? 1 : -1))
: q_data_int.y;
q_data_int.w =
(rand.w < q_error[2] && q_data_int.w > low_q && q_data_int.w < high_q)
? (q_data_int.w + (q_data.w > 0 ? 1 : -1))
: q_data_int.w;
q_data_int.z =
(rand.z < q_error[3] && q_data_int.z > low_q && q_data_int.z < high_q)
? (q_data_int.z + (q_data.z > 0 ? 1 : -1))
: q_data_int.z;
q_data_int.x /= q_scale_val;
q_data_int.y /= q_scale_val;
q_data_int.w /= q_scale_val;
q_data_int.z /= q_scale_val;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( sr_quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
float2 data[MAX_REG];
int group_id = blockIdx.x;
{
int group_index = id;
int reg_count = 0;
int offset = group_id * group_size;
float max = -10000.0;
float min = 10000.0;
while (group_index < group_size && reg_count < MAX_REG) {
data[reg_count] = vals_cast[offset + group_index];
__half* data_h = reinterpret_cast<__half*>(&data[reg_count]);
if (((float)data_h[0]) > max) max = (float)data_h[0];
if (((float)data_h[1]) > max) max = (float)data_h[1];
if (((float)data_h[2]) > max) max = (float)data_h[2];
if (((float)data_h[3]) > max) max = (float)data_h[3];
if (((float)data_h[0]) < min) min = (float)data_h[0];
if (((float)data_h[1]) < min) min = (float)data_h[1];
if (((float)data_h[2]) < min) min = (float)data_h[2];
if (((float)data_h[3]) < min) min = (float)data_h[3];
group_index += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
__half2* data_h = reinterpret_cast<__half2*>(&data[i]);
float2 q_data[2];
q_data[0] = __half22float2(data_h[0]);
q_data[1] = __half22float2(data_h[1]);
float2 q_data_int[2];
q_data_int[0].x = roundf((q_data[0].x - min) * q_scale_inv);
q_data_int[0].y = roundf((q_data[0].y - min) * q_scale_inv);
q_data_int[1].x = roundf((q_data[1].x - min) * q_scale_inv);
q_data_int[1].y = roundf((q_data[1].y - min) * q_scale_inv);
q_data_int[0].x = q_data_int[0].x * q_scale + min;
q_data_int[0].y = q_data_int[0].y * q_scale + min;
q_data_int[1].x = q_data_int[1].x * q_scale + min;
q_data_int[1].y = q_data_int[1].y * q_scale + min;
data_h[0] = __float22half2_rn(q_data_int[0]);
data_h[1] = __float22half2_rn(q_data_int[1]);
vals_cast[offset + group_index] = data[i];
}
}
}
#endif
}
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[MAX_REG];
int bid = blockIdx.x;
int group_index = bid * group_size + id;
int reg_count = 0;
float max = -10000.0;
float min = 10000.0;
while (id < group_size && reg_count < MAX_REG) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
id += blockDim.x;
reg_count++;
}
id = threadIdx.x;
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_inv = 1 / q_scale;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + id;
if (group_index < group_size) {
float4 q_data;
q_data = data[i];
float4 q_data_int;
q_data_int.x = roundf((q_data.x - min) * q_scale_inv);
q_data_int.y = roundf((q_data.y - min) * q_scale_inv);
q_data_int.w = roundf((q_data.w - min) * q_scale_inv);
q_data_int.z = roundf((q_data.z - min) * q_scale_inv);
q_data.x = q_data_int.x * q_scale + min;
q_data.y = q_data_int.y * q_scale + min;
q_data.w = q_data_int.w * q_scale + min;
q_data.z = q_data_int.z * q_scale + min;
vals_cast[group_index + bid * group_size] = q_data;
}
}
}
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
hipLaunchKernelGGL(( quantize_kernel_asym), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
__global__ void sr_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
float2* vals_cast = reinterpret_cast<float2*>(vals);
__half2 data_low[128];
__half2 data_high[128];
int bid = blockIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
unsigned int tid = threadIdx.x;
int reg_count = 0;
int offset = bid * token_size;
int group_index = bid * token_size + tid;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float2 data = vals_cast[offset + tid];
__half2* data_h = reinterpret_cast<__half2*>(&data);
data_low[reg_count] = data_h[0];
data_high[reg_count] = data_h[1];
float2 data_f[2];
data_f[0] = __half22float2(data_h[0]);
data_f[1] = __half22float2(data_h[1]);
if (((float)data_f[0].x) > max) max = (float)data_f[0].x;
if (((float)data_f[0].y) > max) max = (float)data_f[0].y;
if (((float)data_f[1].x) > max) max = (float)data_f[1].x;
if (((float)data_f[1].y) > max) max = (float)data_f[1].y;
if (((float)data_f[0].x) < min) min = (float)data_f[0].x;
if (((float)data_f[0].y) < min) min = (float)data_f[0].y;
if (((float)data_f[1].x) < min) min = (float)data_f[1].x;
if (((float)data_f[1].y) < min) min = (float)data_f[1].y;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float q_scale_val_inv = 1 / q_scale_val;
float high_q = (float)((1 << num_bits) - 1);
for (int i = 0; i < reg_count; i++) {
int token_index = i * blockDim.x + threadIdx.x;
if (token_index < token_size) {
float2 data_f[2];
data_f[0] = __half22float2(data_low[i]);
data_f[1] = __half22float2(data_high[i]);
float2 q_data_int[2];
q_data_int[0].x = (float)((unsigned int)((data_f[0].x - min) * q_scale_val_inv));
q_data_int[0].y = (float)((unsigned int)((data_f[0].y - min) * q_scale_val_inv));
q_data_int[1].x = (float)((unsigned int)((data_f[1].x - min) * q_scale_val_inv));
q_data_int[1].y = (float)((unsigned int)((data_f[1].y - min) * q_scale_val_inv));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] =
abs(data_f[0].x - ((q_data_int[0].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[1] =
abs(data_f[0].y - ((q_data_int[0].y * q_scale_val) + min)) * q_scale_val_inv;
q_error[2] =
abs(data_f[1].x - ((q_data_int[1].x * q_scale_val) + min)) * q_scale_val_inv;
q_error[3] =
abs(data_f[1].y - ((q_data_int[1].y * q_scale_val) + min)) * q_scale_val_inv;
q_data_int[0].x = (rand.x < q_error[0] && q_data_int[0].x < high_q)
? (q_data_int[0].x + 1)
: q_data_int[0].x;
q_data_int[0].y = (rand.y < q_error[1] && q_data_int[0].y < high_q)
? (q_data_int[0].y + 1)
: q_data_int[0].y;
q_data_int[1].x = (rand.w < q_error[2] && q_data_int[1].x < high_q)
? (q_data_int[1].x + 1)
: q_data_int[1].x;
q_data_int[1].y = (rand.z < q_error[3] && q_data_int[1].y < high_q)
? (q_data_int[1].y + 1)
: q_data_int[1].y;
data_f[0].x = q_data_int[0].x * q_scale_val + min;
data_f[0].y = q_data_int[0].y * q_scale_val + min;
data_f[1].x = q_data_int[1].x * q_scale_val + min;
data_f[1].y = q_data_int[1].y * q_scale_val + min;
float2 result;
__half2* result_h = reinterpret_cast<__half2*>(&result);
result_h[0] = __float22half2_rn(data_f[0]);
result_h[1] = __float22half2_rn(data_f[1]);
vals_cast[offset + token_index] = result;
}
}
}
#endif
}
__global__ void sr_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int gid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int id = threadIdx.x;
int idx = blockIdx.x * blockDim.x + id;
float4* vals_cast = reinterpret_cast<float4*>(vals);
float4 data[128];
int bid = blockIdx.x;
int tid = threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
int group_index = bid * token_size + threadIdx.x;
int reg_count = 0;
int total_count = token_size * token_num;
if (group_index < total_count) {
float min = 10000.0;
float max = -10000.0;
while (tid < token_size) {
float4 data_reg = vals_cast[group_index];
data[reg_count] = data_reg;
if (data_reg.x > max) max = data_reg.x;
if (data_reg.y > max) max = data_reg.y;
if (data_reg.w > max) max = data_reg.w;
if (data_reg.z > max) max = data_reg.z;
if (data_reg.x < min) min = data_reg.x;
if (data_reg.y < min) min = data_reg.y;
if (data_reg.w < min) min = data_reg.w;
if (data_reg.z < min) min = data_reg.z;
group_index += blockDim.x;
tid += blockDim.x;
reg_count++;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < WARP_SIZE; i <<= 1) {
auto temp = g.shfl_xor(min, i);
if (min > temp) min = temp;
}
__shared__ float partialMax[WARP_SIZE];
__shared__ float partialMin[WARP_SIZE];
if (lane == 0) partialMax[gid] = max;
if (lane == 0) partialMin[gid] = min;
b.sync();
if (lane < warp_num) max = partialMax[lane];
if (lane < warp_num) min = partialMin[lane];
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(max, i);
if (max < temp) max = temp;
}
#pragma unroll
for (int i = 1; i < warp_num; i <<= 1) {
auto temp = g.shfl_down(min, i);
if (min > temp) min = temp;
}
max = g.shfl(max, 0);
min = g.shfl(min, 0);
float q_scale_val = ((max - min) + 1e-5) / (float)(1 << num_bits);
float high_q = (float)((1 << num_bits) - 1);
int offset = (bid)*token_size;
for (int i = 0; i < reg_count; i++) {
group_index = i * blockDim.x + threadIdx.x;
if (group_index < token_size) {
float4 q_data = data[i];
float4 q_data_int;
q_data_int.x = (float)((int)((q_data.x - min) / q_scale_val));
q_data_int.y = (float)((int)((q_data.y - min) / q_scale_val));
q_data_int.w = (float)((int)((q_data.w - min) / q_scale_val));
q_data_int.z = (float)((int)((q_data.z - min) / q_scale_val));
// Stochastic rounding
float4 rand = hiprand_uniform4(&state);
float q_error[4];
q_error[0] = abs(q_data.x - ((q_data_int.x * q_scale_val) + min)) / q_scale_val;
q_error[1] = abs(q_data.y - ((q_data_int.y * q_scale_val) + min)) / q_scale_val;
q_error[2] = abs(q_data.w - ((q_data_int.w * q_scale_val) + min)) / q_scale_val;
q_error[3] = abs(q_data.z - ((q_data_int.z * q_scale_val) + min)) / q_scale_val;
q_data_int.x = (rand.x < q_error[0] && q_data_int.x < high_q) ? (q_data_int.x + 1)
: q_data_int.x;
q_data_int.y = (rand.y < q_error[1] && q_data_int.y < high_q) ? (q_data_int.y + 1)
: q_data_int.y;
q_data_int.w = (rand.w < q_error[2] && q_data_int.w < high_q) ? (q_data_int.w + 1)
: q_data_int.w;
q_data_int.z = (rand.z < q_error[3] && q_data_int.z < high_q) ? (q_data_int.z + 1)
: q_data_int.z;
q_data_int.x = q_data_int.x * q_scale_val + min;
q_data_int.y = q_data_int.y * q_scale_val + min;
q_data_int.w = q_data_int.w * q_scale_val + min;
q_data_int.z = q_data_int.z * q_scale_val + min;
vals_cast[group_index + offset] = q_data_int;
}
}
}
}
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( sr_quantize_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
hipStream_t stream);
// !!! This is a file automatically generated by hipify!!!
#include "cublas_wrappers_hip.h"
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f32_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f32_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
rocblas_datatype_f32_r,
m,
C,
rocblas_datatype_f32_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR32F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR32F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
C,
hipR32F,
m,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_gemm_algo algo)
#else
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status = rocblas_gemm_ex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
rocblas_datatype_f16_r,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
rocblas_datatype_f16_r,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
rocblas_datatype_f16_r,
m,
(void*)C,
rocblas_datatype_f16_r,
m,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = rocblas_gemmex(handle,
transa,
transb,
m,
n,
k,
(const void*)alpha,
(const void*)A,
hipR16F,
(transa == rocblas_operation_none) ? m : k,
(const void*)B,
hipR16F,
(transb == rocblas_operation_none) ? k : n,
(const void*)beta,
(void*)C,
hipR16F,
m,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const float* A,
const float* B,
float* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status =
rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f32_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f32_r,
m,
stride_C,
C,
rocblas_datatype_f32_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR32F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR32F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR32F,
m,
stride_C,
batch,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n",
batch,
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
#ifdef __HIP_PLATFORM_HCC__
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
rocblas_gemm_algo algo)
#else
int cublas_strided_batched_gemm(rocblas_handle handle,
int m,
int n,
int k,
const float* alpha,
const float* beta,
const __half* A,
const __half* B,
__half* C,
rocblas_operation op_A,
rocblas_operation op_B,
int stride_A,
int stride_B,
int stride_C,
int batch,
cublasGemmAlgo_t algo)
#endif
{
#ifdef __HIP_PLATFORM_HCC__
rocblas_status status =
rocblas_gemm_strided_batched_ex(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
rocblas_datatype_f16_r,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
rocblas_datatype_f16_r,
m,
stride_C,
C,
rocblas_datatype_f16_r,
m,
stride_C,
batch,
rocblas_datatype_f32_r,
algo,
0,
0);
#else
rocblas_status status = cublasGemmStridedBatchedEx(handle,
op_A,
op_B,
m,
n,
k,
alpha,
A,
hipR16F,
(op_A == rocblas_operation_none) ? m : k,
stride_A,
B,
hipR16F,
(op_B == rocblas_operation_none) ? k : n,
stride_B,
beta,
C,
hipR16F,
m,
stride_C,
batch,
hipR32F,
algo);
#endif
#ifdef __HIP_PLATFORM_HCC__
if (status != rocblas_status_success) {
#else
if (status != rocblas_status_success) {
#endif
fprintf(stderr,
"!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n",
m,
n,
k,
(int)status);
return EXIT_FAILURE;
}
return 0;
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
const int unroll_factor = 4;
__global__ void dropout_kernel(const int N,
const float ratio,
float* out,
const float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float4 rand = hiprand_uniform4(&state);
uint8_t m[unroll_factor];
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
int i = j * unroll_factor;
mask[i] = (uint8_t)m[0];
mask[i + 1] = (uint8_t)m[1];
mask[i + 2] = (uint8_t)m[2];
mask[i + 3] = (uint8_t)m[3];
out[i] = Xdata[i] * scale * m[0];
out[i + 1] = Xdata[i + 1] * scale * m[1];
out[i + 2] = Xdata[i + 2] * scale * m[2];
out[i + 3] = Xdata[i + 3] * scale * m[3];
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = Xdata[i] * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const float ratio,
__half* out,
const __half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
uint32_t m_32;
uint8_t* m = reinterpret_cast<uint8_t*>(&m_32);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
__half2 mask_h[2];
float2 mask_f[2];
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
mask_cast[j] = m_32;
}
#else
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
uint8_t m[unroll_factor];
float4 rand = hiprand_uniform4(&state);
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
mask[i] = m[0];
mask[i + 1] = m[1];
mask[i + 2] = m[2];
mask[i + 3] = m[3];
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
out[i] = __float2half((float)Xdata[i] * scale * m);
mask[i] = m;
}
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const float* Xdata,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
out[i] = mask[i] ? Xdata[i] * scale : 0.0;
out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0;
out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0;
out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; }
}
}
__global__ void dropout_kernel_bwd(const int N,
const float ratio,
const __half* Xdata,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
#ifdef __STOCHASTIC_MODE__
const __half2 h_scale = __float2half2_rn(scale);
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_f = x_cast[j];
__half2* x_h = reinterpret_cast<__half2*>(&x_f);
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]);
#pragma unroll
for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = x_h[0] * h_scale * mask_h[0];
result_h[1] = x_h[1] * h_scale * mask_h[1];
out_cast[j] = result_f;
}
#else
const __half h_scale = __float2half(scale);
const __half h_zero = __float2half(0.0);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
int i = j * unroll_factor;
const __half2* vals_half = reinterpret_cast<const __half2*>(Xdata + i);
uint8_t* m = mask + i;
float2 vals_half_f[2];
vals_half_f[0] = __half22float2(vals_half[0]);
vals_half_f[1] = __half22float2(vals_half[1]);
out[i] = __float2half(vals_half_f[0].x * scale * m[0]);
out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]);
out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]);
out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]);
}
#endif
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool bwd)
{
assert(unroll_factor == 4);
dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor);
dim3 block_dim = DS_CUDA_NUM_THREADS;
if (dim > 512) {
block_dim.x >>= 1;
grid_dim.x <<= 1;
}
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
if (bwd)
hipLaunchKernelGGL(( dropout_kernel_bwd), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, vals, out, mask, seed);
else
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, ratio, out, vals, mask, seed);
}
template void launch_dropout(float* out,
const float* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
template void launch_dropout(__half* out,
const __half* vals,
uint8_t* mask,
int total_count,
int dim,
float ratio,
hipStream_t stream,
bool);
__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask)
{
const __half2 h_scale = __float2half2_rn(scale);
float2* x_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_cast = reinterpret_cast<uint32_t*>(mask);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
#ifdef __STOCHASTIC_MODE__
__half2* x_data_h = reinterpret_cast<__half2*>(&x_data);
__half2 mask_h[2];
float2 mask_f[2];
float* mask_f_data = &mask_f[0].x;
#pragma unroll
for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]);
mask_h[0] = __float22half2_rn(mask_f[0]);
mask_h[1] = __float22half2_rn(mask_f[1]);
result_h[0] = x_data_h[0] * h_scale * mask_h[0];
result_h[1] = x_data_h[1] * h_scale * mask_h[1];
#else
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
#endif
x_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, mask);
}
template void launch_dropout_grad(float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_grad_kernel(const int N,
const float scale,
const float* Xdata,
float* out,
uint8_t* mask)
{
CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; }
}
__global__ void dropout_grad_kernel(const int N,
const float scale,
const __half* Xdata,
__half* out,
uint8_t* mask)
{
const float2* x_cast = reinterpret_cast<const float2*>(Xdata);
float2* out_cast = reinterpret_cast<float2*>(out);
const uint32_t* mask_cast = reinterpret_cast<const uint32_t*>(mask);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
CUDA_1D_KERNEL_LOOP(j, N / unroll_factor)
{
float2 x_data = x_cast[j];
uint32_t m_32 = mask_cast[j];
uint8_t* m = (uint8_t*)&m_32;
__half* x_data_h = reinterpret_cast<__half*>(&x_data);
float2 result[2];
result[0].x = (float)x_data_h[0] * scale * m[0];
result[0].y = (float)x_data_h[1] * scale * m[1];
result[1].x = (float)x_data_h[2] * scale * m[2];
result[1].y = (float)x_data_h[3] * scale * m[3];
result_h[0] = __float22half2_rn(result[0]);
result_h[1] = __float22half2_rn(result[1]);
out_cast[j] = result_f;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
for (int i = high_index; i < N; i++) {
out[i] = __float2half((float)Xdata[i] * scale * mask[i]);
}
}
}
template <typename T>
void launch_dropout_grad(T* vals_out,
const T* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
const float scale = 1. / (1. - ratio);
hipLaunchKernelGGL(( dropout_grad_kernel), dim3(DS_GET_BLOCKS(total_count / unroll_factor)),
dim3(DS_CUDA_NUM_THREADS),
0,
stream, total_count, scale, vals, vals_out, mask);
}
template void launch_dropout_grad(float*,
const float* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
template void launch_dropout_grad(__half*,
const __half* vals,
uint8_t* mask,
int total_count,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* bias,
float* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* Xdata_cast = reinterpret_cast<float4*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 x_data = Xdata_cast[j];
float4 b_data = bias_cast[j % (dim / unroll_factor)];
x_data.x += b_data.x;
x_data.y += b_data.y;
x_data.z += b_data.z;
x_data.w += b_data.w;
x_data.x = x_data.x * scale * m[0];
x_data.y = x_data.y * scale * m[1];
x_data.z = x_data.z * scale * m[2];
x_data.w = x_data.w * scale * m[3];
mask_32[j] = m_32;
Xdata_cast[j] = x_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = Xdata[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = x_data * scale * m;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* bias,
__half* Xdata,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* Xdata_cast = reinterpret_cast<float2*>(Xdata);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
data_f = Xdata_cast[j];
bias_f = bias_cast[j % (dim / unroll_factor)];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
data_h_0.x += bias_h_0.x;
data_h_0.y += bias_h_0.y;
data_h_1.x += bias_h_1.x;
data_h_1.y += bias_h_1.y;
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
Xdata_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)Xdata[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
Xdata[i] = __float2half(x_data * scale * m);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const float* input,
const float* residual,
const float* bias,
float* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float4* out_cast = reinterpret_cast<float4*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
const float4* residual_cast = reinterpret_cast<const float4*>(residual);
const float4* input_cast = reinterpret_cast<const float4*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
float4 out_data;
float4 b_data = bias_cast[j % (dim / unroll_factor)];
float4 res_data = residual_cast[j];
float4 inp_data = input_cast[j];
out_data.x = (b_data.x + inp_data.x);
out_data.y = (b_data.y + inp_data.y);
out_data.z = (b_data.z + inp_data.z);
out_data.w = (b_data.w + inp_data.w);
out_data.x = out_data.x * scale * m[0];
out_data.y = out_data.y * scale * m[1];
out_data.z = out_data.z * scale * m[2];
out_data.w = out_data.w * scale * m[3];
out_data.x += res_data.x;
out_data.y += res_data.y;
out_data.z += res_data.z;
out_data.w += res_data.w;
mask_32[j] = m_32;
out_cast[j] = out_data;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = input[i] + bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += residual[i];
out[i] = x_data;
mask[i] = m;
}
}
}
__global__ void dropout_kernel(const int N,
const int dim,
const float ratio,
const __half* input,
const __half* residual,
const __half* bias,
__half* out,
uint8_t* mask,
std::pair<uint64_t, uint64_t> seed)
{
const float scale = 1. / (1. - ratio);
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int tid = threadIdx.x % (dim / unroll_factor);
hiprandStatePhilox4_32_10_t state;
hiprand_init(seed.first, idx, seed.second, &state);
float2* out_cast = reinterpret_cast<float2*>(out);
uint32_t* mask_32 = reinterpret_cast<uint32_t*>(mask);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
const float2* residual_cast = reinterpret_cast<const float2*>(residual);
const float2* input_cast = reinterpret_cast<const float2*>(input);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 rand = hiprand_uniform4(&state);
float2 data_f;
__half2* data_h = reinterpret_cast<__half2*>(&data_f);
float2 bias_f;
__half2* bias_h = reinterpret_cast<__half2*>(&bias_f);
float2 residual_f;
__half2* residual_h = reinterpret_cast<__half2*>(&residual_f);
float2 input_f;
__half2* input_h = reinterpret_cast<__half2*>(&input_f);
bias_f = bias_cast[j % (dim / unroll_factor)];
residual_f = residual_cast[j];
input_f = input_cast[j];
float2 data_h_0 = __half22float2(data_h[0]);
float2 data_h_1 = __half22float2(data_h[1]);
float2 bias_h_0 = __half22float2(bias_h[0]);
float2 bias_h_1 = __half22float2(bias_h[1]);
float2 residual_h_0 = __half22float2(residual_h[0]);
float2 residual_h_1 = __half22float2(residual_h[1]);
float2 input_h_0 = __half22float2(input_h[0]);
float2 input_h_1 = __half22float2(input_h[1]);
data_h_0.x = (bias_h_0.x + input_h_0.x);
data_h_0.y = (bias_h_0.y + input_h_0.y);
data_h_1.x = (bias_h_1.x + input_h_1.x);
data_h_1.y = (bias_h_1.y + input_h_1.y);
uint32_t m_32;
uint8_t* m = (uint8_t*)&m_32;
m[0] = (uint8_t)(rand.x > ratio);
m[1] = (uint8_t)(rand.y > ratio);
m[2] = (uint8_t)(rand.z > ratio);
m[3] = (uint8_t)(rand.w > ratio);
data_h_0.x = __float2half(data_h_0.x * scale * m[0]);
data_h_0.y = __float2half(data_h_0.y * scale * m[1]);
data_h_1.x = __float2half(data_h_1.x * scale * m[2]);
data_h_1.y = __float2half(data_h_1.y * scale * m[3]);
data_h_0.x += residual_h_0.x;
data_h_0.y += residual_h_0.y;
data_h_1.x += residual_h_1.x;
data_h_1.y += residual_h_1.y;
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
result_h[0] = __float22half2_rn(data_h_0);
result_h[1] = __float22half2_rn(data_h_1);
out_cast[j] = result_f;
mask_32[j] = m_32;
}
int high_index =
((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x;
if (N > high_index) {
float4 rand = hiprand_uniform4(&state);
float* rand_data = &(rand.x);
int k = 0;
for (int i = high_index; i < N; i++) {
float x_data = (float)input[i] + (float)bias[i % dim];
uint8_t m = (uint8_t)(rand_data[k++] > ratio);
x_data = x_data * scale * m;
x_data += (float)residual[i];
out[i] = __float2half(x_data);
mask[i] = m;
}
}
}
template <typename T>
void launch_dropout(T* out,
const T* input,
const T* residual,
const T* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream)
{
assert(unroll_factor == 4);
int total_count = batch * dim / unroll_factor;
dim3 grid_dim = DS_GET_BLOCKS(total_count);
dim3 block_dim = DS_CUDA_NUM_THREADS;
uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
hipLaunchKernelGGL(( dropout_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
total_count, dim, ratio, input, residual, bias, out, mask, seed);
}
template void launch_dropout(float*,
const float*,
const float* residual,
const float* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
template void launch_dropout(__half*,
const __half*,
const __half* residual,
const __half* bias,
uint8_t* mask,
int batch,
int dim,
float ratio,
hipStream_t stream);
// !!! This is a file automatically generated by hipify!!!
#include <torch/extension.h>
#include <rocblas.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "Timer_hip.h"
#include "context_hip.h"
#include "cublas_wrappers_hip.h"
#include "custom_hip_layers.h"
#include "ds_transformer_hip.h"
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;
const int init_seq_length = 128;
// C++ interface
template <typename T>
unsigned get_workspace_size(unsigned maxBatchSize,
unsigned seq_len,
unsigned hidden_size,
unsigned intermediate_size,
unsigned heads,
bool training,
bool gelu_checkpoint)
{
unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
if (training) {
workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
2 * (size_t(maxBatchSize) * heads * seq_len * seq_len)));
if (gelu_checkpoint)
workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size);
}
return workSpacesize; // * sizeof(T);
}
// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
template <typename T>
BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
unsigned batch_size,
unsigned hidden_size,
unsigned num_heads,
unsigned intermediate_size,
unsigned seq_length,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
float layer_norm_eps,
bool pre_or_postLayerNorm,
const std::vector<std::array<int, 3>>& gemm_algos,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
: _layer_id(layer_id),
_batch_size(batch_size),
_hidden_size(hidden_size),
_heads(num_heads),
_intermediate_size(intermediate_size),
_seq_length(seq_length),
_training(true),
_pre_or_postLayerNorm(pre_or_postLayerNorm),
_attn_dropout_checkpoint(attn_dropout_checkpoint),
_normalize_invertible(normalize_invertible),
_gelu_checkpoint(gelu_checkpoint),
_stochastic_mode(stochastic_mode),
_stream(Context::Instance().GetCurrentStream()),
_cublasHandle(Context::Instance().GetCublasHandle()),
_qkv_linear(typename FeedForward<T>::Config(batch_size * seq_length,
3 * hidden_size,
hidden_size,
gemm_algos[0])),
_attn_out_linear(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
hidden_size,
gemm_algos[0])),
_attn_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
layer_norm_eps,
true,
!normalize_invertible)),
_layer_norm(typename Normalize_Layer<T>::Config(batch_size,
seq_length,
hidden_size,
layer_norm_eps,
true,
!normalize_invertible)),
_ff1(typename FeedForward<T>::Config(batch_size * seq_length,
_intermediate_size,
hidden_size,
gemm_algos[1])),
_ff2(typename FeedForward<T>::Config(batch_size * seq_length,
hidden_size,
_intermediate_size,
gemm_algos[2])),
_softmax(typename Softmax<T>::Config(batch_size, num_heads, seq_length)),
_gelu(typename Gelu<T>::Config(_intermediate_size)),
_attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio, _seq_length)),
_attn_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_layer_output_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio, _hidden_size)),
_attn_scores(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_seq_length,
_seq_length,
_hidden_size / _heads,
//(T(1.0) / T(sqrt(_hidden_size / _heads))),
//aiss debug 0506
(T(1.0 / (sqrt(_hidden_size / _heads)))),
T(0.0),
rocblas_operation_transpose,
rocblas_operation_none,
gemm_algos[3])),
_attn_context(typename StridedBatchGemm<T>::Config(_batch_size * _heads,
_hidden_size / _heads,
_seq_length,
_seq_length,
T(1.0),
T(0.0),
rocblas_operation_none,
rocblas_operation_none,
gemm_algos[4]))
{
assert(_hidden_size % _heads == 0);
Initialize();
}
template <typename T>
BertTransformerLayer<T>::~BertTransformerLayer()
{
}
template <typename T>
void BertTransformerLayer<T>::Initialize()
{
#ifndef __HIP_PLATFORM_HCC__
if (std::is_same<T, __half>::value) rocblas_set_math_mode(_cublasHandle, CUBLAS_TENSOR_OP_MATH);
#endif
}
template <typename T>
void BertTransformerLayer<T>::Forward(unsigned bsz,
const T* input_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_qkvb_ptr,
const T* attn_ow_ptr,
const T* attn_ob_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* output_b_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* out_ptr,
T* inp_norm_ptr,
T* q_tf_ptr,
T* k_tf_ptr,
T* v_tf_ptr,
T* soft_out_ptr,
T* ctx_bufB_ptr,
T* attn_o_inp_ptr,
T* add_res_ptr,
T* ff1_inp_ptr,
T* gelu_inp_ptr,
T* ff2_inp_ptr)
{
rocblas_set_stream(_cublasHandle, _stream);
if (!_stochastic_mode) hipStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
T* buf_2 = buf_1;
if (_normalize_invertible) {
add_res_ptr = buf_1 + 3 * small_buf_size;
buf_2 = add_res_ptr;
}
if (_gelu_checkpoint) buf_2 += small_buf_size;
if (_attn_dropout_checkpoint)
ctx_bufB_ptr =
(_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size)
: (buf_1 + 4 * small_buf_size));
int bsz_seq = bsz * _seq_length;
if (_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_layer_norm.Forward(
bsz_seq, inp_norm_ptr, input_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
if (_pre_or_postLayerNorm)
_qkv_linear.Forward(bsz_seq, inp_norm_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
else
_qkv_linear.Forward(bsz_seq, input_ptr, attn_qkvw_ptr, buf_0, _cublasHandle);
launch_bias_add_transform_0213<T>(
q_tf_ptr, buf_0, attn_qkvb_ptr, bsz, _seq_length, _hidden_size, _heads, _stream, 3);
int bsz_heads = bsz * _heads;
// attention scores
_attn_scores.Forward(bsz_heads, soft_out_ptr, k_tf_ptr, q_tf_ptr, _cublasHandle);
// Softmax + Mask
_softmax.Forward(bsz, soft_out_ptr, input_mask_ptr, _stream);
// attn prob dropout.
_attn_prob_dropout.Forward(bsz_heads * _seq_length, ctx_bufB_ptr, soft_out_ptr, _stream);
// attention context
_attn_context.Forward(bsz_heads, buf_1, v_tf_ptr, ctx_bufB_ptr, _cublasHandle);
launch_transform4d_0213<T>(
attn_o_inp_ptr, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 1);
if (_pre_or_postLayerNorm)
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, buf_1, _cublasHandle);
else
_attn_out_linear.Forward(bsz_seq, attn_o_inp_ptr, attn_ow_ptr, ff1_inp_ptr, _cublasHandle);
// attn output dropout.
if (_pre_or_postLayerNorm)
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, buf_1, input_ptr, attn_ob_ptr, _stream);
else
_attn_output_dropout.ForwardWithBias(
bsz_seq, add_res_ptr, ff1_inp_ptr, input_ptr, attn_ob_ptr, _stream);
if (_pre_or_postLayerNorm) {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
} else {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.ForwardCheckpoint(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
else
_attn_layer_norm.Forward(
bsz_seq, ff1_inp_ptr, add_res_ptr, attn_nw_ptr, attn_nb_ptr, _stream, true);
}
_ff1.Forward(bsz_seq,
ff1_inp_ptr,
inter_w_ptr,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
_cublasHandle);
_gelu.ForwardWithBiasAdd(bsz_seq,
(_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr),
inter_b_ptr,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
_stream);
_ff2.Forward(
bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle);
// layer output dropout.
if (_pre_or_postLayerNorm)
_layer_output_dropout.ForwardWithBias(
bsz_seq, out_ptr, out_ptr, add_res_ptr, output_b_ptr, _stream);
else
_layer_output_dropout.ForwardWithBias(
bsz_seq, inp_norm_ptr, out_ptr, ff1_inp_ptr, output_b_ptr, _stream);
if (!_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.ForwardCheckpoint(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
else
_layer_norm.Forward(
bsz_seq, out_ptr, inp_norm_ptr, norm_w_ptr, norm_b_ptr, _stream, true);
}
}
template <typename T>
void BertTransformerLayer<T>::Backward(unsigned bsz,
const T* grad_output_ptr,
const T* input_ptr,
const T* output_ptr,
const T* inp_norm_ptr,
const T* q_tf_ptr,
const T* k_tf_ptr,
const T* v_tf_ptr,
const T* soft_out_ptr,
const T* ctx_bufB_ptr,
const T* attn_o_inp_ptr,
const T* add_res_ptr,
const T* ff1_inp_ptr,
const T* gelu_inp_ptr,
const T* ff2_inp_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
const T* attn_ow_ptr,
const T* attn_nw_ptr,
const T* attn_nb_ptr,
const T* inter_w_ptr,
const T* inter_b_ptr,
const T* output_w_ptr,
const T* norm_w_ptr,
const T* norm_b_ptr,
T* grad_input_ptr,
T* grad_attn_qkvw_ptr,
T* grad_attn_qkvb_ptr,
T* grad_attn_ow_ptr,
T* grad_attn_ob_ptr,
T* grad_attn_nw_ptr,
T* grad_attn_nb_ptr,
T* grad_inter_w_ptr,
T* grad_inter_b_ptr,
T* grad_output_w_ptr,
T* grad_output_b_ptr,
T* grad_norm_w_ptr,
T* grad_norm_b_ptr)
{
rocblas_set_stream(_cublasHandle, _stream);
if (!_stochastic_mode) hipStreamSynchronize(_stream);
T* workspace = static_cast<T*>(Context::Instance().GetWorkSpace());
size_t small_buf_size = bsz * _seq_length * _hidden_size;
T* buf_0 = workspace;
T* buf_1 = buf_0 + small_buf_size;
T* buf_2 = buf_1 + small_buf_size;
T* buf_3 = buf_2 + small_buf_size;
T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size)
: buf_3 + small_buf_size);
T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);
hipStream_t streams[2] = {_stream, _stream};
int bsz_seq = bsz * _seq_length;
int bsz_heads = bsz * _heads;
if (!_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
inp_norm_ptr);
else
_layer_norm.Backward(bsz_seq,
grad_output_ptr,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
buf_1,
output_ptr);
}
if (_pre_or_postLayerNorm)
_layer_output_dropout.Backward(bsz_seq, buf_0, grad_output_ptr, _stream);
else
_layer_output_dropout.Backward(bsz_seq, buf_0, buf_1, _stream);
const T* layer_dropout_buf = _layer_output_dropout.HasDropout()
? buf_0
: (_pre_or_postLayerNorm ? grad_output_ptr : buf_1);
if (_gelu_checkpoint)
_gelu.ForwardWithBiasAdd(bsz_seq, ff2_inp_ptr, inter_b_ptr, buf_2, _stream);
_ff2.Backward(bsz_seq,
layer_dropout_buf,
(_gelu_checkpoint ? buf_2 : ff2_inp_ptr),
output_w_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
_cublasHandle,
_stream,
ff2_buf);
_gelu.Backward(
bsz_seq, ff2_buf, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, _stream);
_ff1.Backward(bsz_seq,
ff2_buf,
ff1_inp_ptr,
inter_w_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
_cublasHandle,
_stream,
buf_3);
if (!_pre_or_postLayerNorm)
launch_fused_add2<T>(buf_2, buf_3, buf_1, bsz, _seq_length, _hidden_size, _stream);
if (_pre_or_postLayerNorm) {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_attn_layer_norm.BackwardFusedAdd(bsz_seq,
buf_3,
grad_output_ptr,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
} else {
if (_attn_layer_norm.UseMean())
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
add_res_ptr);
else
_attn_layer_norm.Backward(bsz_seq,
buf_2,
attn_nw_ptr,
attn_nb_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
streams,
buf_0,
ff1_inp_ptr);
}
_attn_output_dropout.Backward(bsz_seq, buf_2, buf_0, _stream);
T* attn_output_dropout_buf = _attn_output_dropout.HasDropout() ? buf_2 : buf_0;
_attn_out_linear.Backward(bsz_seq,
attn_output_dropout_buf,
attn_o_inp_ptr,
attn_ow_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
_cublasHandle,
_stream,
buf_1);
launch_transform_0213<T>(buf_2, buf_1, bsz, _seq_length, _hidden_size, _heads, _stream);
if (_attn_prob_dropout.HasDropout()) {
if (_attn_dropout_checkpoint)
_attn_prob_dropout.Forward(
bsz_heads * _seq_length, ctx_bufB_ptr_recomp, soft_out_ptr, _stream, true);
_attn_context.Backward(bsz_heads,
buf_2,
v_tf_ptr,
(_attn_dropout_checkpoint ? ctx_bufB_ptr_recomp : ctx_bufB_ptr),
_cublasHandle,
buf_3,
ff2_buf);
} else
_attn_context.Backward(
bsz_heads, buf_2, v_tf_ptr, soft_out_ptr, _cublasHandle, buf_3, ff2_buf);
_attn_prob_dropout.Backward(bsz_heads * _seq_length, ff2_buf, _stream);
_softmax.Backward(bsz, ff2_buf, soft_out_ptr, _stream);
_attn_scores.Backward(bsz_heads, ff2_buf, k_tf_ptr, q_tf_ptr, _cublasHandle, buf_2, buf_1);
launch_transform4d_0213(ff2_buf, buf_1, bsz, _heads, _seq_length, _hidden_size, _stream, 3);
if (_pre_or_postLayerNorm)
_qkv_linear.Backward(bsz_seq,
ff2_buf,
inp_norm_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
else
_qkv_linear.Backward(bsz_seq,
ff2_buf,
input_ptr,
attn_qkvw_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
_cublasHandle,
_stream,
buf_2);
if (_pre_or_postLayerNorm) {
if (_layer_norm.UseMean())
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
input_ptr);
else
_layer_norm.BackwardFusedAdd(bsz_seq,
buf_2,
buf_0,
norm_w_ptr,
norm_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr,
streams,
grad_input_ptr,
inp_norm_ptr);
} else
launch_fused_add2<T>(grad_input_ptr, buf_2, buf_0, bsz, _seq_length, _hidden_size, _stream);
}
template <typename T>
void BertTransformerLayer<T>::SetTrainingMode(bool training)
{
// Dropout will be skipped when not in training model.
_attn_prob_dropout.SetTrainingMode(training);
_attn_output_dropout.SetTrainingMode(training);
_layer_output_dropout.SetTrainingMode(training);
}
template <typename T>
void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_mask_ptr,
uint8_t* attn_output_dropout_mask_ptr,
uint8_t* layer_output_dropout_mask_ptr,
T* attn_layer_norm_var,
T* attn_layer_norm_mean,
T* layer_norm_var,
T* layer_norm_mean)
{
_attn_prob_dropout.SetMask(attn_prob_dropout_mask_ptr);
_attn_output_dropout.SetMask(attn_output_dropout_mask_ptr);
_layer_output_dropout.SetMask(layer_output_dropout_mask_ptr);
_attn_layer_norm.SetVar(attn_layer_norm_var);
_attn_layer_norm.SetMean(attn_layer_norm_mean);
_layer_norm.SetVar(layer_norm_var);
_layer_norm.SetMean(layer_norm_mean);
}
template <typename T>
void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
_seq_length = seq_len;
_softmax.SetSeqLength(_seq_length);
_attn_prob_dropout.SetDimension(_seq_length);
_attn_scores.SetConfig(_seq_length, _seq_length, _hidden_size / _heads);
_attn_context.SetConfig(_hidden_size / _heads, _seq_length, _seq_length);
}
template <typename T>
int create_transformer_layer(unsigned layer_id,
unsigned batch_size,
unsigned hidden_dim,
unsigned num_heads,
unsigned intermediate_size,
float attn_dropout_ratio,
float hidden_dropout_ratio,
float layer_norm_eps,
int seed,
bool pre_or_postLayerNorm,
bool test_gemm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint,
bool stochastic_mode)
{
Context::Instance().SetSeed(seed);
Context::Instance().TestGemmFP16(
test_gemm, batch_size, init_seq_length, num_heads, hidden_dim / num_heads);
auto layer = std::make_shared<BertTransformerLayer<T>>(layer_id,
batch_size,
hidden_dim,
num_heads,
intermediate_size,
init_seq_length,
attn_dropout_ratio,
hidden_dropout_ratio,
layer_norm_eps,
pre_or_postLayerNorm,
Context::Instance().GetGemmAlgos(),
attn_dropout_checkpoint,
normalize_invertible,
gelu_checkpoint,
stochastic_mode);
s_transformer_layers[layer_id] = layer;
std::string dtype = (std::is_same<T, __half>::value) ? "half" : "float";
std::cout << "layer #" << layer_id << " is created with date type [" << dtype << "]."
<< std::endl;
return 0;
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b,
bool training_mode,
bool prelayernorm,
bool attn_dropout_checkpoint,
bool normalize_invertible,
bool gelu_checkpoint)
{
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
unsigned bsz = input.size(0);
const T* input_ptr = (const T*)input.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_qkvb_ptr = (const T*)attn_qkvb.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_ob_ptr = (const T*)attn_ob.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* output_b_ptr = (const T*)output_b.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
auto output = torch::empty_like(input);
T* out_ptr = (T*)output.data_ptr();
auto options = torch::TensorOptions()
.dtype(input.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto uint8_options = torch::TensorOptions()
.dtype(torch::kInt8)
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(false);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
unsigned seq_len = layer->GetSeqLength();
if (input.size(1) != seq_len) {
seq_len = input.size(1);
layer->SetSeqLength(seq_len);
}
auto workspace = torch::empty({get_workspace_size<T>(bsz,
seq_len,
layer->GetHiddenSize(),
layer->GetIntermediateSize(),
layer->GetNumHeads(),
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto inp_norm = ((prelayernorm || !normalize_invertible) ? torch::empty_like(input) : output);
auto add_res = (normalize_invertible ? inp_norm : torch::empty_like(input));
auto attn_o_inp = torch::empty_like(input);
auto qkv_tf = torch::empty({(bsz * seq_len), output_w.size(0) * 3}, options);
auto attn_prob_dropout_mask =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, uint8_options);
auto attn_output_dropout_mask =
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto layer_output_dropout_mask =
torch::empty({(bsz * seq_len), layer->GetHiddenSize()}, uint8_options);
auto attn_layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto attn_layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_var = torch::empty({(bsz * seq_len)}, options);
auto layer_norm_mean = torch::empty({(bsz * seq_len)}, options);
T* inp_norm_ptr = (T*)inp_norm.data_ptr();
T* add_res_ptr = (T*)add_res.data_ptr();
T* q_tf_ptr = (T*)qkv_tf.data_ptr();
T* k_tf_ptr = q_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)k_tf.data_ptr();
T* v_tf_ptr = k_tf_ptr + (bsz * seq_len * output_w.size(0)); //(T*)v_tf.data_ptr();
T* attn_o_inp_ptr = (T*)attn_o_inp.data_ptr();
torch::Tensor ff2_inp = torch::empty({(bsz * seq_len), output_w.size(1)}, options);
torch::Tensor gelu_inp =
(gelu_checkpoint ? ff2_inp : torch::empty({(bsz * seq_len), output_w.size(1)}, options));
auto ff1_inp = torch::empty_like(input);
T* ff2_inp_ptr = (T*)ff2_inp.data_ptr();
T* gelu_inp_ptr = (T*)gelu_inp.data_ptr();
T* ff1_inp_ptr = (T*)ff1_inp.data_ptr();
torch::Tensor soft_out =
torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options);
torch::Tensor ctx_bufB =
(attn_dropout_checkpoint
? soft_out
: torch::empty({(bsz * layer->GetNumHeads() * seq_len), seq_len}, options));
T* soft_out_ptr = (T*)soft_out.data_ptr();
T* ctx_bufB_ptr = (T*)ctx_bufB.data_ptr();
layer->SetTrainingMode(training_mode);
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Forward(bsz,
input_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_qkvb_ptr,
attn_ow_ptr,
attn_ob_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
output_b_ptr,
norm_w_ptr,
norm_b_ptr,
out_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr);
return {output,
inp_norm,
qkv_tf,
soft_out,
ctx_bufB,
attn_o_inp,
add_res,
ff1_inp,
gelu_inp,
ff2_inp,
attn_prob_dropout_mask,
attn_output_dropout_mask,
layer_output_dropout_mask,
attn_layer_norm_var,
attn_layer_norm_mean,
layer_norm_var,
layer_norm_mean};
}
template <typename T>
std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
const torch::Tensor& grad_output,
const torch::Tensor& output,
const torch::Tensor& inp_norm,
const torch::Tensor& qkv_tf,
const torch::Tensor& soft_out,
const torch::Tensor& ctx_bufB,
const torch::Tensor& attn_o_inp,
const torch::Tensor& add_res,
const torch::Tensor& ff1_inp,
const torch::Tensor& gelu_inp,
const torch::Tensor& ff2_inp,
const torch::Tensor& attn_prob_dropout_mask,
const torch::Tensor& attn_output_dropout_mask,
const torch::Tensor& layer_output_dropout_mask,
const torch::Tensor& attn_layer_norm_var,
const torch::Tensor& attn_layer_norm_mean,
const torch::Tensor& layer_norm_var,
const torch::Tensor& layer_norm_mean,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
const torch::Tensor& attn_qkvb,
const torch::Tensor& attn_ow,
const torch::Tensor& attn_ob,
const torch::Tensor& attn_nw,
const torch::Tensor& attn_nb,
const torch::Tensor& inter_w,
const torch::Tensor& inter_b,
const torch::Tensor& output_w,
const torch::Tensor& output_b,
const torch::Tensor& norm_w,
const torch::Tensor& norm_b)
{
auto g_output = grad_output.contiguous();
CHECK_INPUT(g_output);
CHECK_INPUT(output);
CHECK_INPUT(inp_norm);
CHECK_INPUT(qkv_tf);
CHECK_INPUT(add_res);
CHECK_INPUT(soft_out);
CHECK_INPUT(ctx_bufB);
CHECK_INPUT(attn_o_inp);
CHECK_INPUT(ff1_inp);
CHECK_INPUT(gelu_inp);
CHECK_INPUT(ff2_inp);
CHECK_INPUT(input);
CHECK_INPUT(input_mask);
CHECK_INPUT(attn_qkvw);
CHECK_INPUT(attn_qkvb);
CHECK_INPUT(attn_ow);
CHECK_INPUT(attn_ob);
CHECK_INPUT(attn_nw);
CHECK_INPUT(attn_nb);
CHECK_INPUT(inter_w);
CHECK_INPUT(inter_b);
CHECK_INPUT(output_w);
CHECK_INPUT(output_b);
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);
unsigned bsz = g_output.size(0);
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
unsigned seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len);
}
auto options = torch::TensorOptions()
.dtype(g_output.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto workspace = torch::empty({get_workspace_size<T>(bsz,
seq_len,
layer->GetHiddenSize(),
layer->GetIntermediateSize(),
layer->GetNumHeads(),
layer->IsTrainingMode(),
layer->GeluCheckpoint())},
options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto grad_input = torch::empty_like(input);
auto grad_attn_qkvw = torch::empty_like(attn_qkvw);
auto grad_attn_qkvb = torch::empty_like(attn_qkvb);
auto grad_attn_ow = torch::empty_like(attn_ow);
auto grad_attn_ob = torch::empty_like(attn_ob);
auto grad_attn_nw = torch::empty_like(attn_nw);
auto grad_attn_nb = torch::empty_like(attn_nb);
auto grad_inter_w = torch::empty_like(inter_w);
auto grad_inter_b = torch::empty_like(inter_b);
auto grad_output_w = torch::empty_like(output_w);
auto grad_output_b = torch::empty_like(output_b);
auto grad_norm_w = torch::empty_like(norm_w);
auto grad_norm_b = torch::empty_like(norm_b);
// inputs.
const T* grad_output_ptr = (const T*)g_output.data_ptr();
const T* input_ptr = (const T*)input.data_ptr();
const T* output_ptr = (const T*)output.data_ptr();
const T* inp_norm_ptr = (const T*)inp_norm.data_ptr();
const T* q_tf_ptr = (const T*)qkv_tf.data_ptr();
const T* add_res_ptr = (const T*)add_res.data_ptr();
const T* k_tf_ptr =
q_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)k_tf.data_ptr();
const T* v_tf_ptr =
k_tf_ptr + (bsz * layer->GetSeqLength() * output_w.size(0)); //(const T*)v_tf.data_ptr();
const T* ff1_inp_ptr = (const T*)ff1_inp.data_ptr();
const T* gelu_inp_ptr = (const T*)gelu_inp.data_ptr();
const T* ff2_inp_ptr = (const T*)ff2_inp.data_ptr();
const T* ctx_bufB_ptr = (const T*)ctx_bufB.data_ptr();
const T* soft_out_ptr = (const T*)soft_out.data_ptr();
const T* attn_o_inp_ptr = (const T*)attn_o_inp.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
const T* attn_qkvw_ptr = (const T*)attn_qkvw.data_ptr();
const T* attn_ow_ptr = (const T*)attn_ow.data_ptr();
const T* attn_nw_ptr = (const T*)attn_nw.data_ptr();
const T* attn_nb_ptr = (const T*)attn_nb.data_ptr();
const T* inter_w_ptr = (const T*)inter_w.data_ptr();
const T* inter_b_ptr = (const T*)inter_b.data_ptr();
const T* output_w_ptr = (const T*)output_w.data_ptr();
const T* norm_w_ptr = (const T*)norm_w.data_ptr();
const T* norm_b_ptr = (const T*)norm_b.data_ptr();
// outputs.
T* grad_input_ptr = (T*)grad_input.data_ptr();
T* grad_attn_qkvw_ptr = (T*)grad_attn_qkvw.data_ptr();
T* grad_attn_qkvb_ptr = (T*)grad_attn_qkvb.data_ptr();
T* grad_attn_ow_ptr = (T*)grad_attn_ow.data_ptr();
T* grad_attn_ob_ptr = (T*)grad_attn_ob.data_ptr();
T* grad_attn_nw_ptr = (T*)grad_attn_nw.data_ptr();
T* grad_attn_nb_ptr = (T*)grad_attn_nb.data_ptr();
T* grad_inter_w_ptr = (T*)grad_inter_w.data_ptr();
T* grad_inter_b_ptr = (T*)grad_inter_b.data_ptr();
T* grad_output_w_ptr = (T*)grad_output_w.data_ptr();
T* grad_output_b_ptr = (T*)grad_output_b.data_ptr();
T* grad_norm_w_ptr = (T*)grad_norm_w.data_ptr();
T* grad_norm_b_ptr = (T*)grad_norm_b.data_ptr();
layer->SetIntermediateBuffers((uint8_t*)attn_prob_dropout_mask.data_ptr(),
(uint8_t*)attn_output_dropout_mask.data_ptr(),
(uint8_t*)layer_output_dropout_mask.data_ptr(),
(T*)attn_layer_norm_var.data_ptr(),
(T*)attn_layer_norm_mean.data_ptr(),
(T*)layer_norm_var.data_ptr(),
(T*)layer_norm_mean.data_ptr());
layer->Backward(bsz,
grad_output_ptr,
input_ptr,
output_ptr,
inp_norm_ptr,
q_tf_ptr,
k_tf_ptr,
v_tf_ptr,
soft_out_ptr,
ctx_bufB_ptr,
attn_o_inp_ptr,
add_res_ptr,
ff1_inp_ptr,
gelu_inp_ptr,
ff2_inp_ptr,
input_mask_ptr,
attn_qkvw_ptr,
attn_ow_ptr,
attn_nw_ptr,
attn_nb_ptr,
inter_w_ptr,
inter_b_ptr,
output_w_ptr,
norm_w_ptr,
norm_b_ptr,
grad_input_ptr,
grad_attn_qkvw_ptr,
grad_attn_qkvb_ptr,
grad_attn_ow_ptr,
grad_attn_ob_ptr,
grad_attn_nw_ptr,
grad_attn_nb_ptr,
grad_inter_w_ptr,
grad_inter_b_ptr,
grad_output_w_ptr,
grad_output_b_ptr,
grad_norm_w_ptr,
grad_norm_b_ptr);
return {grad_input,
grad_attn_qkvw,
grad_attn_qkvb,
grad_attn_ow,
grad_attn_ob,
grad_attn_nw,
grad_attn_nb,
grad_inter_w,
grad_inter_b,
grad_output_w,
grad_output_b,
grad_norm_w,
grad_norm_b};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("forward_fp32",
&ds_transformer_forward<float>,
"DeepSpeed Transformer forward with fp32 (CUDA)");
m.def("forward_fp16",
&ds_transformer_forward<__half>,
"DeepSpeed Transformer forward with fp16 (CUDA)");
m.def("backward_fp32",
&ds_transformer_backward<float>,
"DeepSpeed Transformer backward with fp32 (CUDA)");
m.def("backward_fp16",
&ds_transformer_backward<__half>,
"DeepSpeed Transformer backward with fp16 (CUDA)");
m.def("create_transformer_layer_fp32",
&create_transformer_layer<float>,
"Create DeepSpeed Transformer Transformer Layer with fp32 (CUDA)");
m.def("create_transformer_layer_fp16",
&create_transformer_layer<__half>,
"Create DeepSpeed Transformer Transformer Layer with fp16 (CUDA)");
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "custom_hip_layers.h"
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
inline __device__ float d_gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
float x2mul = x * x * mul_param;
float tan_h = tanhf(sqrt_param * (x + x * x2mul));
float dg1 = 0.5f * (1.0f + tan_h);
float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h);
float dg3 = dg2 * 3 * x2mul;
return (dg1 + dg2 + dg3);
}
/*
Fused bias add with GELU
Loads a vector of 4 elements each iteration, for stride
iterations. It was written with the intention to launch 256 thread
threadblocks, so to launch for bert-large, we would set ITERATIONS
to 4. This is currently done automatically as a heuristic, setting
the number of iterations as blocks of 1024.
For FP16, the values are loaded from memory as __half, but converted
to FP32 for the arithmetic itself, to prevent numerous overflow on
the intermediate hyperbolic tangent, since there's no intrinsic
that computes it directly.
*/
__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void fused_bias_gelu(const float* input,
const float* bias,
float* vals,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float4* input_cast = reinterpret_cast<const float4*>(input);
float4* vals_cast = reinterpret_cast<float4*>(vals);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 data = input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
data.x += bias_data.x;
data.y += bias_data.y;
data.z += bias_data.z;
data.w += bias_data.w;
data.x = gelu(data.x);
data.y = gelu(data.y);
data.z = gelu(data.z);
data.w = gelu(data.w);
vals_cast[row * row_stride + i * loop_stride + id] = data;
}
}
}
__global__ void fused_bias_gelu(const __half* input,
const __half* bias,
__half* vals,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
const float2* input_cast = reinterpret_cast<const float2*>(input);
float2* vals_cast = reinterpret_cast<float2*>(vals);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 low_data = __half22float2(vals_half[0]);
float2 high_data = __half22float2(vals_half[1]);
float2 low_bias = __half22float2(bias_half[0]);
float2 high_bias = __half22float2(bias_half[1]);
low_data.x += low_bias.x;
low_data.y += low_bias.y;
high_data.x += high_bias.x;
high_data.y += high_bias.y;
low_data.x = gelu(low_data.x);
low_data.y = gelu(low_data.y);
high_data.x = gelu(high_data.x);
high_data.y = gelu(high_data.y);
vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
vals_cast[row * row_stride + i * loop_stride + id] = vals_vec;
}
}
#endif
}
__global__ void d_gelu_func(float* d_output,
const float* gelu_input,
const float* bias,
int row_stride,
int iterations)
{
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float4* d_output_cast = reinterpret_cast<float4*>(d_output);
const float4* gelu_input_cast = reinterpret_cast<const float4*>(gelu_input);
const float4* bias_cast = reinterpret_cast<const float4*>(bias);
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float4 bias_data = bias_cast[i * loop_stride + id];
gelu_input_data.x += bias_data.x;
gelu_input_data.y += bias_data.y;
gelu_input_data.z += bias_data.z;
gelu_input_data.w += bias_data.w;
output_data.x *= d_gelu(gelu_input_data.x);
output_data.y *= d_gelu(gelu_input_data.y);
output_data.z *= d_gelu(gelu_input_data.z);
output_data.w *= d_gelu(gelu_input_data.w);
d_output_cast[row * row_stride + i * loop_stride + id] = output_data;
}
}
}
__global__ void d_gelu_func(__half* d_output,
const __half* gelu_input,
const __half* bias,
int row_stride,
int iterations)
{
#ifdef HALF_PRECISION_AVAILABLE
int row = blockIdx.x;
int id = threadIdx.x;
int loop_stride = blockDim.x;
float2* d_output_cast = reinterpret_cast<float2*>(d_output);
const float2* gelu_input_cast = reinterpret_cast<const float2*>(gelu_input);
const float2* bias_cast = reinterpret_cast<const float2*>(bias);
#pragma unroll
for (int i = 0; i < iterations; i++) {
if (i * loop_stride + id < row_stride) {
float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id];
float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id];
float2 bias_vec = bias_cast[i * loop_stride + id];
__half2* output_data_half = reinterpret_cast<__half2*>(&output_data);
__half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
float2 output_half_0 = __half22float2(output_data_half[0]);
float2 output_half_1 = __half22float2(output_data_half[1]);
float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]);
float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]);
float2 bias_half_0 = __half22float2(bias_half[0]);
float2 bias_half_1 = __half22float2(bias_half[1]);
gelu_input_half_0.x += bias_half_0.x;
gelu_input_half_0.y += bias_half_0.y;
gelu_input_half_1.x += bias_half_1.x;
gelu_input_half_1.y += bias_half_1.y;
output_half_0.x *= d_gelu(gelu_input_half_0.x);
output_half_0.y *= d_gelu(gelu_input_half_0.y);
output_half_1.x *= d_gelu(gelu_input_half_1.x);
output_half_1.y *= d_gelu(gelu_input_half_1.y);
float2 result;
__half2* result_half2 = reinterpret_cast<__half2*>(&result);
result_half2[0] = __float22half2_rn(output_half_0);
result_half2[1] = __float22half2_rn(output_half_1);
d_output_cast[row * row_stride + i * loop_stride + id] = result;
}
}
#endif
}
template <typename T>
void launch_bias_gelu(const T* input,
const T* bias,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, output, intermediate_size / 4, iterations);
}
template <typename T>
void launch_gelu(const T* input,
T* output,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( gelu_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
input, output, intermediate_size / 4, iterations);
}
template void launch_bias_gelu<float>(const float*, const float*, float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(const __half*,
const __half*,
__half*,
int,
int,
hipStream_t);
template void launch_gelu<float>(const float*, float*, int, int, hipStream_t);
template void launch_gelu<__half>(const __half*, __half*, int, int, hipStream_t);
template <typename T>
void launch_d_gelu(T* d_output,
const T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
int iterations = (intermediate_size + 1023) / 1024;
int threads = (intermediate_size - 1) / (iterations * 4) + 1;
dim3 block_dims(threads);
dim3 grid_dims(batch_size);
hipLaunchKernelGGL(( d_gelu_func), dim3(grid_dims), dim3(block_dims), 0, stream,
d_output, input, bias, intermediate_size / 4, iterations);
}
template void launch_d_gelu<float>(float*, const float*, const float*, int, int, hipStream_t);
template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, hipStream_t);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include "general_kernels_hip.h"
namespace cg = cooperative_groups;
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
T* __restrict__ out,
int rows,
int width)
{
__shared__ float tile[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int y_stride = width * TILE_DIM;
float localSum = 0;
// Loop across matrix height
if (idx < width) {
int offset = threadIdx.y * width + idx;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
localSum += (float)inp[offset];
offset += y_stride;
}
}
tile[threadIdx.x][threadIdx.y] = localSum;
__syncthreads();
// Sum the shared buffer.
float sum = tile[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
if (pos < width) out[pos] = sum;
}
}
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
T* out,
int rows,
int cols,
hipStream_t stream);
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
float* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<float>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
__half* out,
int rows,
int cols,
hipStream_t stream)
{
// assert(rows % TILE_DIM == 0);
// assert(cols % TILE_DIM == 0);
dim3 grid_dim((cols - 1) / TILE_DIM + 1);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( column_sum_reduce<__half>), dim3(grid_dim), dim3(block_dim), 0, stream, inp, out, rows, cols);
}
__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2)
{
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
float4* out_4 = reinterpret_cast<float4*>(out);
CUDA_1D_KERNEL_LOOP(j, N)
{
float4 val;
float4 inp1_reg = inp1_4[j];
float4 inp2_reg = inp2_4[j];
val.x = inp1_reg.x + inp2_reg.x;
val.y = inp1_reg.y + inp2_reg.y;
val.z = inp1_reg.z + inp2_reg.z;
val.w = inp1_reg.w + inp2_reg.w;
out_4[j] = val;
}
}
__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2)
{
float2 inp1_4;
float2 inp2_4;
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
CUDA_1D_KERNEL_LOOP(j, N)
{
inp1_4 = inp1_arr[j];
inp2_4 = inp2_arr[j];
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
inp1_h_f_0.x += inp2_h_f_0.x;
inp1_h_f_0.y += inp2_h_f_0.y;
inp1_h_f_1.x += inp2_h_f_1.x;
inp1_h_f_1.y += inp2_h_f_1.y;
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[j] = val_f;
}
}
template <>
void launch_fused_add2<float>(float* out,
const float* inp1,
const float* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
template <>
void launch_fused_add2<__half>(__half* out,
const __half* inp1,
const __half* inp2,
int batch_size,
int seq_length,
int hidden_dim,
hipStream_t& stream)
{
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
hipLaunchKernelGGL(( fused_add2_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, total_count, out, inp1, inp2);
}
__global__ void fused_add3_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add3_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add3<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add3<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add3_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
__global__ void fused_add4_kernel(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
float4* out_4 = reinterpret_cast<float4*>(out);
float4 val;
float4 inp1_reg = inp1_4[row * row_stride + id];
float4 inp2_reg = inp2_4[row * row_stride + id];
float4 inp3_reg = inp3_4[row * row_stride + id];
float4 inp4_reg = inp4_4[row * row_stride + id];
val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;
out_4[row * row_stride + id] = val;
}
__global__ void fused_add4_kernel(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int size,
int row_stride)
{
int row = blockIdx.x;
int id = threadIdx.x;
const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
float2 inp1_4 = inp1_arr[row * row_stride + id];
float2 inp2_4 = inp2_arr[row * row_stride + id];
float2 inp3_4 = inp3_arr[row * row_stride + id];
float2 inp4_4 = inp4_arr[row * row_stride + id];
__half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4);
__half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4);
__half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4);
__half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4);
float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
float2 inp1_h_f_1 = __half22float2(inp1_h[1]);
float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
float2 inp2_h_f_1 = __half22float2(inp2_h[1]);
float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
float2 inp3_h_f_1 = __half22float2(inp3_h[1]);
float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
float2 inp4_h_f_1 = __half22float2(inp4_h[1]);
inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);
float2 val_f;
__half2* val_h = reinterpret_cast<__half2*>(&val_f);
val_h[0] = __float22half2_rn(inp1_h_f_0);
val_h[1] = __float22half2_rn(inp1_h_f_1);
float2* out_4 = reinterpret_cast<float2*>(out);
out_4[row * row_stride + id] = val_f;
}
template <>
void launch_fused_add4<float>(float* out,
const float* inp1,
const float* inp2,
const float* inp3,
const float* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
template <>
void launch_fused_add4<__half>(__half* out,
const __half* inp1,
const __half* inp2,
const __half* inp3,
const __half* inp4,
int batch_size,
int seq_length,
int hidden_size,
hipStream_t& stream)
{
dim3 grid_dim(batch_size * seq_length);
dim3 block_dim(hidden_size / 4);
hipLaunchKernelGGL(( fused_add4_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4);
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "inference_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
namespace cg = cooperative_groups;
namespace cg = cooperative_groups;
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[k_offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
unsigned seq_index = head_id % seq_len;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[k_offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[k_offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb1(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
int max_out_tokens)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned seq_index = head_id % seq_len;
unsigned offset = head_id * head_size;
unsigned k_offset = (seq_index + (head_id / seq_len) * max_out_tokens) * head_size;
constexpr unsigned mask[32] = {
0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000,
0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, 0x100 | 0x100000, 0x200 | 0x200000,
0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4,
0x8000 | 0x8, 0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
0x2000000, 0x4000000, 0x8000000, 0x10000000, 0x20000000,
0x40000000, 0x80000000};
unsigned seq_id = (head_id % seq_len) + seq_offset;
unsigned half_dim = rotary_dim >> 1;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[k_offset + lane];
float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
: __shfl_sync(mask[lane], q_rot, lane - half_dim);
auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
: __shfl_sync(mask[lane], k_rot, lane - half_dim);
q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[k_offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int max_out_tokens)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
if (rotate_every_two)
hipLaunchKernelGGL(( apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream, mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
else if (rotate_half)
hipLaunchKernelGGL(( apply_rotary_pos_emb1), dim3(grid_dims), dim3(block_dims), 0, stream, mixed_query,
key_layer,
rotary_dim,
seq_len,
offset,
num_heads,
head_size,
total_count,
max_out_tokens);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t,
int);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
bool,
bool,
hipStream_t,
int);
/*
__global__ void apply_rotary_pos_emb(float* mixed_query,
float* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = mixed_query[offset + lane];
float k = key_layer[offset + lane];
float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
q_rot = g.shfl_xor(q_rot, 1);
k_rot = g.shfl_xor(k_rot, 1);
q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
mixed_query[offset + lane] = q;
key_layer[offset + lane] = k;
lane += WARP_SIZE;
}
}
}
__global__ void apply_rotary_pos_emb(__half* mixed_query,
__half* key_layer,
unsigned rotary_dim,
unsigned seq_len,
unsigned seq_offset,
unsigned num_heads,
unsigned head_size,
unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int id = threadIdx.x;
int gid = id >> 5;
int lane = id & 0x1f;
unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
unsigned offset = head_id * head_size;
constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000,
0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000,
0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000,
0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8,
0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80,
0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800,
0x1000000, 0x2000000, 0x4000000, 0x8000000,
0x10000000, 0x20000000, 0x40000000, 0x80000000};
unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
if (head_id < total_count) {
while (lane < rotary_dim) {
//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q = (float)mixed_query[offset + lane];
float k = (float)key_layer[offset + lane];
float rotary_sign = (lane > 11 ? -1.0 : 1.0);
float q_rot = (q * rotary_sign);
float k_rot = (k * rotary_sign);
auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane],
q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? __shfl_sync(mask[lane],
k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q *
cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
mixed_query[offset + lane] = (__half)q;
key_layer[offset + lane] = (__half)k;
lane += WARP_SIZE;
}
}
#endif
}
template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
T* key_layer,
unsigned head_size,
unsigned seq_len,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
unsigned batch,
hipStream_t stream)
{
int total_count = batch * num_heads * seq_len;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size);
hipLaunchKernelGGL((
apply_rotary_pos_emb), dim3(grid_dims), dim3(block_dims), 0, stream,
mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}
template void launch_apply_rotary_pos_emb<float>(float*,
float*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
__half*,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
*/
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "inference_cuda_layers.h"
#define MAX_QUANTIZE_GROUPING 1024
#define loop_unroll 1
#define loop_unroll_bits 1
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int output_size,
int hidden_dim,
int groups,
int merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = (scale_data * (float)q);
tid += blockDim.x;
}
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count)
{
unsigned merge_hidden = hidden_dim >> merge_count;
unsigned quantization_stride = (merge_hidden * output_size) / groups;
unsigned bid = blockIdx.x;
unsigned tid = threadIdx.x;
while (tid < output_size) {
unsigned w_index = bid / merge_hidden;
unsigned q_index = tid + bid * output_size;
auto q = input[q_index];
unsigned merge_hidden_total = w_index * merge_hidden;
unsigned scale_index =
((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride)
<< merge_count) +
w_index;
float scale_data = qscale[scale_index];
output[q_index] = __float2half(scale_data * (float)q);
tid += blockDim.x;
}
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
unsigned merge_count,
hipStream_t stream)
{
unsigned threads = 1024;
dim3 block_dims(threads);
dim3 grid_dims(hidden_dim);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, output_size, hidden_dim, groups, merge_count);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
unsigned,
hipStream_t);
__global__ void dequantize_kernel(float* output,
const int8_t* input,
const float* qscale,
int hidden_dim,
unsigned merge_hidden,
int cnt)
{
}
__global__ void dequantize_kernel(__half* output,
const int8_t* input,
const float* qscale,
unsigned hidden_dim,
unsigned merge_hidden,
int cnt)
{
unsigned bid = blockIdx.x * gridDim.y + blockIdx.y;
unsigned tid = threadIdx.x;
float local_scale = qscale[blockIdx.x];
const float* input_cast = reinterpret_cast<const float*>(input);
float2* output_cast = reinterpret_cast<float2*>(output);
input_cast += bid * merge_hidden;
output_cast += bid * merge_hidden;
for (int c = 0; c < cnt; c++) {
if (tid < merge_hidden) {
float q = input_cast[tid];
int8_t* q_int8 = (int8_t*)&q;
float2 q_f;
__half* q_h = (__half*)&q_f;
q_h[0] = __float2half(local_scale * (float)q_int8[0]);
q_h[1] = __float2half(local_scale * (float)q_int8[1]);
q_h[2] = __float2half(local_scale * (float)q_int8[2]);
q_h[3] = __float2half(local_scale * (float)q_int8[3]);
output_cast[tid] = q_f;
tid += blockDim.x;
}
}
}
template <typename T>
void launch_dequantize(T* output,
const int8_t* input,
const float* qscale,
unsigned output_size,
unsigned hidden_dim,
unsigned groups,
hipStream_t stream)
{
unsigned threads = 1024;
hidden_dim /= 4;
unsigned hid_cnt = threads / hidden_dim;
unsigned thd_cnt = (hidden_dim - 1) / threads + 1;
hid_cnt = hid_cnt > 0 ? hid_cnt : 1;
unsigned blocks = (output_size + hid_cnt * groups - 1) / (hid_cnt * groups);
dim3 block_dims(threads);
dim3 grid_dims(groups, blocks);
hipLaunchKernelGGL(( dequantize_kernel), dim3(grid_dims), dim3(block_dims), 0, stream,
output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt);
}
template void launch_dequantize<float>(float*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
hipStream_t);
template void launch_dequantize<__half>(__half*,
const int8_t*,
const float*,
unsigned,
unsigned,
unsigned,
hipStream_t);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float gelu(const float x)
{
const float sqrt_param = 0.79788456080286535587989211986876f;
const float mul_param = 0.044715;
return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
}
/*
In-place gelu(biasAdd(x)) for channels last
*/
template <typename T>
__global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size)
{
// Input restriction: intermediate_size % vals_per_access == 0
constexpr int granularity = 16;
constexpr int values_per_access = granularity / sizeof(T);
const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
if (offset < total_count) {
T data[values_per_access];
T data_bias[values_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(gelu(data_f + bias_f));
}
mem_access::store_global<granularity>(input + offset, data);
}
}
template <typename T>
void launch_bias_gelu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
constexpr int threads = 1024;
constexpr int granularity = 16;
const int total_count = batch_size * intermediate_size;
const int elems_per_block = threads * (granularity / sizeof(T));
dim3 block_dims(threads);
dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
hipLaunchKernelGGL(( fused_bias_gelu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size);
}
template void launch_bias_gelu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_gelu<__half>(__half*, const __half*, int, int, hipStream_t);
/*
In-place channels-last bias add
*/
template <typename T>
__global__ void fused_bias_add(T* input, const T* bias, int total_count, int intermediate_size)
{
// Input restriction: intermediate_size % vals_per_access == 0
constexpr int granularity = 16;
constexpr int values_per_access = granularity / sizeof(T);
const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
if (offset < total_count) {
T data[values_per_access];
T data_bias[values_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(data_f + bias_f);
}
mem_access::store_global<granularity>(input + offset, data);
}
}
template <typename T>
void launch_bias_add(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
constexpr int threads = 1024;
constexpr int granularity = 16;
const int total_count = batch_size * intermediate_size;
const int elems_per_block = threads * (granularity / sizeof(T));
dim3 block_dims(threads);
dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
hipLaunchKernelGGL(( fused_bias_add), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size);
}
template void launch_bias_add<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, hipStream_t);
__global__ void fused_bias_residual(float* residual,
const float* hidden_state,
const float* attn,
const float* bias,
const float* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln)
{
float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 res_fl4 = res_fl4_ptr[offset];
const float4 hs_fl4 = hs_fl4_ptr[offset];
const float4 attn_fl4 = attn_fl4_ptr[offset];
const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
const float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
if (preln) {
// residual = (residual + attention + bias + attention_bias) *
// mp_scale + hidden_state
res_fl4.x =
(res_fl4.x + attn_fl4.x + bias_fl4.x + attn_bias_fl4.x) * mp_scale + (hs_fl4.x);
res_fl4.y =
(res_fl4.y + attn_fl4.y + bias_fl4.y + attn_bias_fl4.y) * mp_scale + (hs_fl4.y);
res_fl4.z =
(res_fl4.z + attn_fl4.z + bias_fl4.z + attn_bias_fl4.z) * mp_scale + (hs_fl4.z);
res_fl4.w =
(res_fl4.w + attn_fl4.w + bias_fl4.w + attn_bias_fl4.w) * mp_scale + (hs_fl4.w);
} else {
// residual += hidden_state + bias
res_fl4.x = res_fl4.x + hs_fl4.x + bias_fl4.x;
res_fl4.y = res_fl4.y + hs_fl4.y + bias_fl4.y;
res_fl4.z = res_fl4.z + hs_fl4.z + bias_fl4.z;
res_fl4.w = res_fl4.w + hs_fl4.w + bias_fl4.w;
}
res_fl4_ptr[offset] = res_fl4;
}
}
__global__ void fused_bias_residual(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale,
const bool preln)
{
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 res_fl2 = res_fl2_ptr[offset];
const float2 hs_fl2 = hs_fl2_ptr[offset];
const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
float2 res_low = __half22float2(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
if (preln) {
// residual = (residual + attention + bias + attention_bias) *
// mp_scale + hidden_state
res_low.x =
(res_low.x + attn_low.x + bias_low.x + attn_bias_low.x) * mp_scale + hs_low.x;
res_low.y =
(res_low.y + attn_low.y + bias_low.y + attn_bias_low.y) * mp_scale + hs_low.y;
res_high.x =
(res_high.x + attn_high.x + bias_high.x + attn_bias_high.x) * mp_scale + hs_high.x;
res_high.y =
(res_high.y + attn_high.y + bias_high.y + attn_bias_high.y) * mp_scale + hs_high.y;
} else {
// residual += hidden_state + bias
res_low.x = (res_low.x + hs_low.x + bias_low.x);
res_low.y = (res_low.y + hs_low.y + bias_low.y);
res_high.x = (res_high.x + hs_high.x + bias_high.x);
res_high.y = (res_high.y + hs_high.y + bias_high.y);
}
res_half2[0] = __float22half2_rn(res_low);
res_half2[1] = __float22half2_rn(res_high);
res_fl2_ptr[offset] = res_fl2;
}
}
template <typename T>
void launch_bias_residual(T* residual,
T* hidden_state,
T* attn,
T* bias,
T* attn_bias,
int batch,
int hidden_dim,
int mp_size,
bool preln,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( fused_bias_residual), dim3(grid_dims), dim3(block_dims), 0, stream, residual,
hidden_state,
attn,
bias,
attn_bias,
total_count,
hidden_dim / 4,
1.0 / mp_size,
preln);
}
template void launch_bias_residual<
float>(float*, float*, float*, float*, float*, int, int, int, bool, hipStream_t);
template void launch_bias_residual<
__half>(__half*, __half*, __half*, __half*, __half*, int, int, int, bool, hipStream_t);
__global__ void gptj_residual_add(float* residual,
const float* hidden_state,
const float* attn,
const float* bias,
const float* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale)
{
float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float4 res_fl4 = res_fl4_ptr[offset];
const float4 hs_fl4 = hs_fl4_ptr[offset];
const float4 attn_fl4 = attn_fl4_ptr[offset];
const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
if (attn_bias) {
float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
// residual += attention_bias
res_fl4.x += attn_bias_fl4.x;
res_fl4.y += attn_bias_fl4.y;
res_fl4.z += attn_bias_fl4.z;
res_fl4.w += attn_bias_fl4.w;
}
// residual = hidden_state + attention + (residual + bias) * mp_scale
res_fl4.x = hs_fl4.x + attn_fl4.x + (res_fl4.x + bias_fl4.x) * mp_scale;
res_fl4.y = hs_fl4.y + attn_fl4.y + (res_fl4.y + bias_fl4.y) * mp_scale;
res_fl4.z = hs_fl4.z + attn_fl4.z + (res_fl4.z + bias_fl4.z) * mp_scale;
res_fl4.w = hs_fl4.w + attn_fl4.w + (res_fl4.w + bias_fl4.w) * mp_scale;
res_fl4_ptr[offset] = res_fl4;
}
}
__global__ void gptj_residual_add(__half* residual,
const __half* hidden_state,
const __half* attn,
const __half* bias,
const __half* attn_bias,
const int total_count,
const int intermediate_size,
const float mp_scale)
{
float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
const int offset = blockIdx.x * blockDim.x + threadIdx.x;
if (offset < total_count) {
float2 res_fl2 = res_fl2_ptr[offset];
const float2 hs_fl2 = hs_fl2_ptr[offset];
const float2 attn_fl2 = attn_fl2_ptr[offset];
const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
__half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
float2 res_low = __half22float2(res_half2[0]);
float2 res_high = __half22float2(res_half2[1]);
const float2 hs_low = __half22float2(hs_half2[0]);
const float2 hs_high = __half22float2(hs_half2[1]);
const float2 attn_low = __half22float2(attn_half2[0]);
const float2 attn_high = __half22float2(attn_half2[1]);
const float2 bias_low = __half22float2(bias_half2[0]);
const float2 bias_high = __half22float2(bias_half2[1]);
if (attn_bias) {
const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];
const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);
const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);
// residual += attention_bias
res_low.x += attn_bias_low.x;
res_low.y += attn_bias_low.y;
res_high.x += attn_bias_high.x;
res_high.y += attn_bias_high.y;
}
// residual = hidden_state + attention + (residual + bias) * mp_scale
res_low.x = attn_low.x + hs_low.x + (res_low.x + bias_low.x) * mp_scale;
res_low.y = attn_low.y + hs_low.y + (res_low.y + bias_low.y) * mp_scale;
res_high.x = attn_high.x + hs_high.x + (res_high.x + bias_high.x) * mp_scale;
res_high.y = attn_high.y + hs_high.y + (res_high.y + bias_high.y) * mp_scale;
res_half2[0] = __float22half2_rn(res_low);
res_half2[1] = __float22half2_rn(res_high);
res_fl2_ptr[offset] = res_fl2;
}
}
template <typename T>
void launch_gptj_residual_add(T* residual,
T* hidden_state,
T* attn,
T* bias,
T* attn_bias,
int hidden_dim,
int batch,
int mp_size,
hipStream_t stream)
{
int total_count = batch * hidden_dim / 4;
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
hipLaunchKernelGGL(( gptj_residual_add), dim3(grid_dims), dim3(block_dims), 0, stream,
residual, hidden_state, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size);
}
template void launch_gptj_residual_add<float>(float*,
float*,
float*,
float*,
float*,
int,
int,
int,
hipStream_t);
template void launch_gptj_residual_add<__half>(__half*,
__half*,
__half*,
__half*,
__half*,
int,
int,
int,
hipStream_t);
template <typename T>
__global__ void moe_res_matmul(T* residual, T* coef, T* mlp_out, int seq_len, int hidden_dim)
{
constexpr int granularity = 16;
constexpr int vals_per_access = granularity / sizeof(T);
T* residual_seq = residual + blockIdx.x * hidden_dim;
T* mlp_out_seq = mlp_out + blockIdx.x * hidden_dim;
for (unsigned tid = threadIdx.x * vals_per_access; tid < hidden_dim;
tid += blockDim.x * vals_per_access) {
T mlp[vals_per_access];
T res[vals_per_access];
T coef1[vals_per_access];
T coef2[vals_per_access];
mem_access::load_global<granularity>(mlp, mlp_out_seq + tid);
mem_access::load_global<granularity>(res, residual_seq + tid);
mem_access::load_global<granularity>(coef1, coef + tid);
mem_access::load_global<granularity>(coef2, coef + tid + hidden_dim);
#pragma unroll
for (int idx = 0; idx < vals_per_access; idx++) {
mlp[idx] = mlp[idx] * coef2[idx] + res[idx] * coef1[idx];
}
mem_access::store_global<granularity>(mlp_out_seq + tid, mlp);
}
}
template <typename T>
void launch_moe_res_matmul(T* residual,
T* coef,
T* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream)
{
dim3 grid_dim(seq_len);
dim3 block_dim(1024);
hipLaunchKernelGGL(( moe_res_matmul), dim3(grid_dim), dim3(block_dim), 0, stream,
residual, coef, mlp_out, seq_len, hidden_dim);
}
template void launch_moe_res_matmul(float* residual,
float* coef,
float* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
template void launch_moe_res_matmul(__half* residual,
__half* coef,
__half* mlp_out,
int seq_len,
int hidden_dim,
hipStream_t stream);
__global__ void pad_data_kernel(__half* padded_output,
__half* output,
int head_size,
int padded_head_size)
{
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bid = blockIdx.x * (blockDim.y) + threadIdx.y;
int idx = threadIdx.x;
padded_output_cast += (bid * padded_head_size);
output_cast += (bid * head_size);
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
if (idx < head_size)
padded_output_cast[idx] = output_cast[idx];
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_data_kernel(float* padded_output,
float* output,
int head_size,
int padded_head_size)
{
}
template <typename T>
void pad_data(T* padded_output,
T* output,
int bsz,
int head_size,
int padded_head_size,
hipStream_t stream)
{
dim3 grid_dim((bsz - 1) / 16 + 1);
dim3 block_dim(padded_head_size / 8, 16);
hipLaunchKernelGGL(( pad_data_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
padded_output, output, head_size / 8, padded_head_size / 8);
}
template void pad_data(__half* padded_output,
__half* output,
int bsz,
int head_size,
int padded_head_size,
hipStream_t stream);
template void pad_data(float* padded_output,
float* output,
int bsz,
int head_size,
int padded_head_size,
hipStream_t stream);
__global__ void pad_head_seq_kernel(__half* padded_output,
__half* output,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size)
{
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bsz = blockIdx.x;
int bid = blockIdx.y * (blockDim.y) + threadIdx.y;
int idx = threadIdx.x;
padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size;
output_cast += (bsz * seq_len + bid) * head_size;
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
if (idx < head_size && bid < seq_len)
padded_output_cast[idx] = output_cast[idx];
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_head_seq_kernel(float* padded_output,
float* output,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size)
{
}
template <typename T>
void pad_head_seq(T* padded_output,
T* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
hipStream_t stream)
{
dim3 grid_dim(bsz, padded_seq_len / 16);
dim3 block_dim(padded_head_size / 8, 16);
hipLaunchKernelGGL(( pad_head_seq_kernel), dim3(grid_dim), dim3(block_dim), 0, stream,
padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8);
}
template void pad_head_seq(__half* padded_output,
__half* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
hipStream_t stream);
template void pad_head_seq(float* padded_output,
float* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
hipStream_t stream);
// TODO(cmikeh2): evaluate different GeLU performance
__device__ __forceinline__ float old_gelu(float val)
{
// 1 / sqrt(2)
constexpr float rsqrt_2 = 0.707106769084930419922;
return val * 0.5f * (1.0f + erff(val * rsqrt_2));
}
namespace fused_geglu {
constexpr int threads = 256;
constexpr int steps = 2;
constexpr int granularity = 16;
} // namespace fused_geglu
template <typename T>
__global__ void fused_bias_geglu(T* output,
const T* activation,
const T* bias,
int base_channels,
int total_elems)
{
constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
constexpr int T_per_step = T_per_access * fused_geglu::threads;
constexpr int T_per_block = T_per_step * fused_geglu::steps;
const int id = blockIdx.x * T_per_block + threadIdx.x * T_per_access;
#pragma unroll
for (int i = 0; i < fused_geglu::steps; i++) {
T activation_buffer_1[T_per_access];
T activation_buffer_2[T_per_access];
T bias_buffer_1[T_per_access];
T bias_buffer_2[T_per_access];
const int iter_id = id + T_per_step * i;
if (iter_id < total_elems) {
const int channel_id = iter_id % base_channels;
const int seq_id = iter_id / base_channels;
const int seq_offset = seq_id * base_channels * 2;
mem_access::load_global<fused_geglu::granularity>(activation_buffer_1,
activation + seq_offset + channel_id);
mem_access::load_global<fused_geglu::granularity>(
activation_buffer_2, activation + seq_offset + channel_id + base_channels);
mem_access::load_global<fused_geglu::granularity>(bias_buffer_1, bias + channel_id);
mem_access::load_global<fused_geglu::granularity>(bias_buffer_2,
bias + channel_id + base_channels);
// Since the GeLU is going to happen at float, might as well
// convert
#pragma unroll
for (int v = 0; v < T_per_access; v++) {
T hidden_state = activation_buffer_1[v] + bias_buffer_1[v];
T pre_gate = activation_buffer_2[v] + bias_buffer_2[v];
float gate_f = old_gelu(conversion::to<float>(pre_gate));
T gate = conversion::to<T>(gate_f);
activation_buffer_1[v] = hidden_state * gate;
}
mem_access::store_global<fused_geglu::granularity>(output + iter_id,
activation_buffer_1);
}
}
}
template <typename T>
void launch_fused_bias_geglu(T* output,
const T* activation,
const T* bias,
int rows,
int elems_per_row,
hipStream_t stream)
{
/*
Fused bias GEGLU is a variant of the gated activation functions.
The input here is a matrix of [batch, seq_len, 2 * intermediate_dim]
where the second half of the channels act as GeLU gates for the first
half.
*/
// Re-derive the above figures
constexpr int T_per_access = fused_geglu::granularity / sizeof(T);
constexpr int T_per_step = T_per_access * fused_geglu::threads;
constexpr int T_per_block = T_per_step * fused_geglu::steps;
const int base_channels = elems_per_row / 2;
const int total_elems = base_channels * rows;
dim3 block(fused_geglu::threads);
dim3 grid((total_elems + T_per_block - 1) / T_per_block);
hipLaunchKernelGGL(( fused_bias_geglu), dim3(grid), dim3(block), 0, stream,
output, activation, bias, base_channels, total_elems);
}
template void launch_fused_bias_geglu(__half*,
const __half*,
const __half*,
int,
int,
hipStream_t);
template void launch_fused_bias_geglu(float*, const float*, const float*, int, int, hipStream_t);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "inference_cuda_layers.h"
#include "memory_access_utils.h"
#include "reduction_utils.h"
namespace cg = cooperative_groups;
using rop = reduce::ROpType;
namespace ln {
constexpr int granularity = 16;
} // namespace ln
/*
Primary layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
Args:
output: buffer for output data
vals: buffer for input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
*/
template <typename T, int unRoll, int threadsPerGroup, int maxThreads>
__global__ void fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = tb.size() * T_per_load;
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
T local_buffer[unRoll * T_per_load];
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
T residual_buffer[T_per_load];
T bias_buffer[T_per_load];
mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
sum = reduce::element<rop::Add>(sum, vals_up_cast);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
const T mean_compute = conversion::to<T>(mean);
const T denom_compute = conversion::to<T>(denom);
T* block_output = output + block_offset;
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute;
iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j];
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
#define LAUNCH_FUSED_LN(unRollFactor, threadsPerGroup, maxThreads) \
hipLaunchKernelGGL(( fused_ln<T, unRollFactor, threadsPerGroup, maxThreads>) \
, dim3(grid), dim3(block), 0, stream, output, vals, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_ln(T* output,
const T* vals,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
hipStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For Flaoat, unRoll 4, for __half, unRoll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_LN(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_LN(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_LN(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_LN(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_LN(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_LN(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_LN(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_LN(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_LN(4 * internal_unRoll, maxThreads, maxThreads);
}
}
template void launch_fused_ln(__half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
hipStream_t);
template void
launch_fused_ln(float*, const float*, const float*, const float*, float, int, int, hipStream_t);
/*
Fused resiual + bias + layer norm implementation. Assumes elems_per_row % 8
is equal to 0.
TODO(cmikeh2): Goal is to deprecate this implementation. The bias + residual
need to be fused into compute-bound producer operations.
Args:
output: buffer for output data
res_output: output of residual addition
vals: buffer for input data
residual: residual data
bias: bias of of input data
gamma: gain for normalization
beta: bias for normalization
epsilon: numeric stability
elems_per_row: number of elements each block will normalize
Template arg:
StoreResidual: controls whether the residual calculation is stored
or not. When set to false, the input `res_output` is unused.
*/
template <typename T, int unRoll, int threadsPerGroup, int maxThreads, bool preLnResidual>
__global__ void fused_residual_ln(T* output,
T* res_output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int elems_per_row)
{
constexpr int T_per_load = ln::granularity / sizeof(T);
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// X-dimension of the block
const int block_offset = (tb.group_index().x * (maxThreads / threadsPerGroup) * elems_per_row) +
(tb.thread_index().y * elems_per_row);
const int thread_offset = tb.thread_index().x * T_per_load;
const int base_offset = block_offset + thread_offset;
const int stride = tb.size() * T_per_load;
float sum = reduce::init<rop::Add, float>();
const T* input_base = vals + base_offset;
const T* residual_base = residual + base_offset;
const T* bias_base = bias + thread_offset;
T local_buffer[unRoll * T_per_load];
// Unlike a vanilla layernorm, since we're fusing the two adds as well
// an inner unRoll seems to be less valuable. If anything, a double unRoll
// makes the most sense if we find we are having performance issues.
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
T residual_buffer[T_per_load];
T bias_buffer[T_per_load];
mem_access::load_global<ln::granularity>(
iteration_buffer, input_base + i * stride, thread_offset + i * stride < elems_per_row);
mem_access::load_global<ln::granularity>(residual_buffer,
residual_base + i * stride,
thread_offset + i * stride < elems_per_row);
mem_access::load_global<ln::granularity>(
bias_buffer, bias_base + i * stride, thread_offset + i * stride < elems_per_row);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
float vals_up_cast = conversion::to<float>(iteration_buffer[j]);
float res_up_cast = conversion::to<float>(residual_buffer[j]);
float bias_up_cast = conversion::to<float>(bias_buffer[j]);
vals_up_cast += res_up_cast + bias_up_cast;
sum = reduce::element<rop::Add>(sum, vals_up_cast);
iteration_buffer[j] = conversion::to<T>(vals_up_cast);
}
if (preLnResidual && (thread_offset + i * stride < elems_per_row)) {
mem_access::store_global<ln::granularity>(res_output + base_offset + i * stride,
iteration_buffer);
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, sum);
const float mean = sum / elems_per_row;
float mean_diff = reduce::init<rop::Add, float>();
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
// Using a 0 value here skews the variance, have to if-guard
if (thread_offset + i * stride < elems_per_row) {
float diff = (conversion::to<float>(local_buffer[i * T_per_load + j]) - mean);
mean_diff = reduce::element<rop::Add>(mean_diff, diff * diff);
}
}
}
reduce::partitioned_block<rop::Add, threadsPerGroup>(tb, warp, mean_diff);
const float variance = mean_diff / elems_per_row;
const float denom = __frsqrt_rn(variance + epsilon);
const T mean_compute = conversion::to<T>(mean);
const T denom_compute = conversion::to<T>(denom);
T* block_output = output + block_offset;
#pragma unRoll
for (int i = 0; i < unRoll; i++) {
T* iteration_buffer = local_buffer + i * T_per_load;
const int iter_idx = i * stride + thread_offset;
const bool do_loads = iter_idx < elems_per_row;
T gamma_local[T_per_load], beta_local[T_per_load];
mem_access::load_global<ln::granularity>(gamma_local, gamma + iter_idx, do_loads);
mem_access::load_global<ln::granularity>(beta_local, beta + iter_idx, do_loads);
#pragma unRoll
for (int j = 0; j < T_per_load; j++) {
iteration_buffer[j] = (iteration_buffer[j] - mean_compute) * denom_compute;
iteration_buffer[j] = iteration_buffer[j] * gamma_local[j] + beta_local[j];
}
if (do_loads) {
mem_access::store_global<ln::granularity>(block_output + iter_idx, iteration_buffer);
}
}
}
// TODO(cmikeh2): There's a bunch of redundancy here that needs to be removed/simplified.
#define LAUNCH_FUSED_RES_LN(unRollFactor, threadsPerGroup, maxThreads) \
hipLaunchKernelGGL(( fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, false>) \
, dim3(grid), dim3(block), 0, stream, \
output, nullptr, vals, residual, bias, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_residual_ln(T* output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
hipStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For Flaoat, unRoll 4, for __half, unRoll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_RES_LN(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_RES_LN(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_RES_LN(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_RES_LN(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_RES_LN(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_RES_LN(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_RES_LN(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_RES_LN(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_RES_LN(4 * internal_unRoll, maxThreads, maxThreads);
}
}
#define LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(unRollFactor, threadsPerGroup, maxThreads) \
hipLaunchKernelGGL(( fused_residual_ln<T, unRollFactor, threadsPerGroup, maxThreads, true>) \
, dim3(grid), dim3(block), 0, stream, \
norm_output, res_output, vals, residual, bias, gamma, beta, epsilon, elems_per_row);
template <typename T>
void launch_fused_residual_ln_store_pre_ln_res(T* norm_output,
T* res_output,
const T* vals,
const T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int rows,
int elems_per_row,
hipStream_t stream)
{
// 8 for __half, 4 for float
constexpr int T_per_load = ln::granularity / sizeof(T);
constexpr int maxThreads = 256;
// For Flaoat, unRoll 4, for __half, unRoll 2
constexpr int internal_unRoll = sizeof(T) == 4 ? 4 : 2;
const bool is_subblock_schedule = (elems_per_row <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? T_per_load : T_per_load * internal_unRoll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_row + h_per_step - 1) / h_per_step);
const int threadsPerGroup = (one_step_threads < maxThreads) ? one_step_threads : maxThreads;
const int groups_per_block_max =
is_subblock_schedule ? (maxThreads + threadsPerGroup - 1) / threadsPerGroup : 1;
const int groups_per_block = (rows < groups_per_block_max) ? rows : groups_per_block_max;
const int groups_launch = (groups_per_block + rows - 1) / groups_per_block;
dim3 block(threadsPerGroup, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threadsPerGroup * h_per_step;
const int external_unRoll = (elems_per_row + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threadsPerGroup == 1) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 1, maxThreads);
} else if (threadsPerGroup == 2) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 2, maxThreads);
} else if (threadsPerGroup == 4) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 4, maxThreads);
} else if (threadsPerGroup == 8) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 8, maxThreads);
} else if (threadsPerGroup == 16) {
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1, 16, maxThreads);
}
} else if (external_unRoll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(1 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 2) {
// 4097 - 8192 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(2 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 3) {
// 8193 - 12288 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(3 * internal_unRoll, maxThreads, maxThreads);
} else if (external_unRoll == 4) {
// 12289 - 16384 elems
LAUNCH_FUSED_RES_LN_STORE_PRE_LN_RES(4 * internal_unRoll, maxThreads, maxThreads);
}
}
// No-store specializations
template void launch_fused_residual_ln(__half*,
const __half*,
const __half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
hipStream_t);
template void launch_fused_residual_ln(float*,
const float*,
const float*,
const float*,
const float*,
const float*,
float,
int,
int,
hipStream_t);
// Store specializations
template void launch_fused_residual_ln_store_pre_ln_res(__half*,
__half*,
const __half*,
const __half*,
const __half*,
const __half*,
const __half*,
float,
int,
int,
hipStream_t);
template void launch_fused_residual_ln_store_pre_ln_res(float*,
float*,
const float*,
const float*,
const float*,
const float*,
const float*,
float,
int,
int,
hipStream_t);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
#include <limits>
#include "custom_hip_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define NORM_REG (MAX_REGISTERS)
namespace cg = cooperative_groups;
__global__ void fused_bias_residual_layer_norm(float* output,
const float* vals,
const float* gamma,
const float* beta,
float epsilon,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
float sum = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
output[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_bias_residual_layer_norm(__half* output,
const __half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
const __half2* vals_cast = reinterpret_cast<const __half2*>(vals);
__half2* out_cast = reinterpret_cast<__half2*>(output);
int k = 0;
int input_id = id;
while (input_id < row_stride) {
inp_reg[k++] = vals_cast[input_id + row * row_stride];
input_id += iteration_stride;
}
float sum = 0;
for (int f = k - 1; f >= 0; f--) {
float2 inp_f = __half22float2(inp_reg[f]);
sum += inp_f.x + inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
out_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_layer_norm(T* out,
T* vals,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream);
template <>
void launch_layer_norm<float>(float* out,
float* vals,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim);
}
template <>
void launch_layer_norm<__half>(__half* out,
__half* vals,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
out, vals, gamma, beta, epsilon, hidden_dim / 2);
}
__global__ void fused_residual_layer_norm(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
float inp_reg[NORM_REG];
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals[input_id + row * row_stride];
float res_f = (residual[input_id + row * row_stride]);
float bias_f = (bias[input_id]);
if (mlp_after_attn) inp_reg[k] += res_f + bias_f;
// if (preLN) res_add[input_id + row * row_stride] = inp_reg[k];
sum += inp_reg[k++];
input_id += iteration_stride;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride);
sum = 0.f;
for (int f = 0; f < k; f++) {
inp_reg[f] -= mean;
sum += inp_reg[f] * inp_reg[f];
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride);
sum += epsilon;
sum = __frsqrt_rn(sum);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * sum;
inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id];
norm[out_id + row * row_stride] = inp_reg[f];
}
}
__global__ void fused_residual_layer_norm(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int row_stride,
bool preLN,
bool mlp_after_attn)
{
#ifdef HALF_PRECISION_AVAILABLE
int iteration_stride = blockDim.x;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
int warp_num = iteration_stride >> 5;
__half2 inp_reg[NORM_REG];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
__half2* norm_cast = reinterpret_cast<__half2*>(norm);
__half2* res_add_cast = reinterpret_cast<__half2*>(res_add);
__half2* residual_cast = reinterpret_cast<__half2*>(residual);
const __half2* bias_cast = reinterpret_cast<const __half2*>(bias);
int k = 0;
int input_id = id;
float sum = 0;
while (input_id < row_stride) {
inp_reg[k] = vals_cast[input_id + row * row_stride];
float2 inp_f = __half22float2(inp_reg[k]);
float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]);
float2 bias_f = __half22float2(bias_cast[input_id]);
if (mlp_after_attn) {
inp_f.x += res_f.x + bias_f.x;
inp_f.y += res_f.y + bias_f.y;
}
inp_reg[k] = __float22half2_rn(inp_f);
// if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f);
// //inp_reg[k];
sum += inp_f.x + inp_f.y;
input_id += iteration_stride;
k++;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
__shared__ float shr[MAX_WARP_NUM];
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
float mean = sum / (row_stride << 1);
sum = 0.f;
for (int f = 0; f < k; f++) {
float2 inp_f = __half22float2(inp_reg[f]);
inp_f.x -= mean;
inp_f.y -= mean;
inp_reg[f] = __float22half2_rn(inp_f);
sum += inp_f.x * inp_f.x;
sum += inp_f.y * inp_f.y;
}
for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i);
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()];
b.sync();
for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (row_stride << 1);
sum += epsilon;
sum = __frsqrt_rn(sum);
__half2 variance_h = __float2half2_rn(sum);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
for (int f = 0; f < k; f++) {
int out_id = f * iteration_stride + id;
inp_reg[f] = inp_reg[f] * variance_h;
inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id];
norm_cast[out_id + row * row_stride] = inp_reg[f];
}
#endif
}
template <typename T>
void launch_residual_layer_norm(T* norm,
T* res_add,
T* vals,
T* residual,
const T* bias,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream);
template <>
void launch_residual_layer_norm<float>(float* norm,
float* res_add,
float* vals,
float* residual,
const float* bias,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim,
preLN,
mlp_after_attn);
}
template <>
void launch_residual_layer_norm<__half>(__half* norm,
__half* res_add,
__half* vals,
__half* residual,
const __half* bias,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
bool preLN,
bool mlp_after_attn,
hipStream_t stream)
{
constexpr int threads = 1024;
dim3 grid_dim(batch_size);
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream, norm,
res_add,
vals,
residual,
bias,
gamma,
beta,
epsilon,
hidden_dim / 2,
preLN,
mlp_after_attn);
}
// !!! This is a file automatically generated by hipify!!!
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
#include <torch/extension.h>
#include <stdexcept>
#include <vector>
#include "inference_context.h"
#include "inference_cublas_wrappers.h"
#include "inference_cuda_layers.h"
std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
// NOTE: This activation function type enum should be always in sync
// with the python counterpart, otherwise the casting from python binding
// will be incorrect.
enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2 };
enum class TransformerType : uint8_t { UNKNOWN = 0, GPTType = 1, BERTType = 2 };
// NOTE: this is a temporary and dodgy solution to distinguish GPT and BERT style models
// based on the dimensions of the corresponding attention mask.
inline auto infer_transformer_type(at::Tensor& attn_mask) -> TransformerType
{
auto attn_mask_num_dims = attn_mask.sizes().size();
if (attn_mask_num_dims > 2) {
return TransformerType::GPTType;
} else if (attn_mask_num_dims == 2) {
return TransformerType::BERTType;
} else {
return TransformerType::UNKNOWN;
}
}
// infer stride of attention mask memory layout based on the model type.
inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int
{
auto trnsfrmr_type = infer_transformer_type(attn_mask);
if (trnsfrmr_type == TransformerType::GPTType) {
return attn_mask.size(2);
} else if (trnsfrmr_type == TransformerType::BERTType) {
// Bert style models have always a mask stride of 1.
return 1;
} else if (trnsfrmr_type == TransformerType::UNKNOWN) {
return 0;
}
// this is just to make the compiler happy.
return 0;
}
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
at::Tensor& alibi,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
bool async_op,
float layer_scale,
int head_offset,
int mp_size)
{
auto attn_scores_c = attn_scores.contiguous();
int bsz = attn_scores_c.size(0);
int seq_len = attn_scores_c.size(1);
int len = attn_scores_c.sizes().size();
if (len > 2) seq_len = attn_scores_c.size(2);
int soft_len = attn_scores_c.size(2);
if (len > 3) soft_len = attn_scores_c.size(3);
int heads = 1;
if (len > 1) heads = attn_scores_c.size(1);
auto mask_stride = get_attn_mask_stride(attn_mask);
launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
(alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr),
layer_scale,
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
head_offset,
mask_stride,
mp_size,
Context::Instance().GetCurrentStream(async_op));
return attn_scores_c;
}
template <typename T>
void allocate_workspace(unsigned hidden_dim,
unsigned num_heads,
unsigned prompt_length,
unsigned batch_size,
unsigned num_layers,
unsigned mp_size = 1,
bool external_cache = false,
unsigned rank = 0,
unsigned max_out_tokens = 1024)
{
Context::Instance().GenWorkSpace(num_layers,
num_heads,
batch_size,
prompt_length,
hidden_dim,
mp_size,
external_cache,
sizeof(T),
rank,
max_out_tokens);
}
template <typename T>
at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
{
auto options = at::TensorOptions()
.dtype(Q.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
float alpha = 1;
float gemm_beta = 0.0;
/*
// Reallocate memory if we received a new prompt
if (!workspace || input.size(1) != 1) {
allocate_workspace<T>(W.size(1), Context::Instance().GetMaxTokenLenght(), Q.size(0), 1,
head_size); workspace = (T*)Context::Instance().GetWorkSpace();
}
*/
auto O = at::from_blob(workspace, {Q.size(1), Q.size(2), W.size(1)}, options);
unsigned m = W.size(1);
unsigned n = Q.size(1) * Q.size(2);
unsigned k = Q.size(0);
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_transpose,
m,
n,
k,
&alpha,
&gemm_beta,
(T*)W.data_ptr(),
(T*)Q.data_ptr(),
(T*)O.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
return O;
}
template <typename T>
void attention_unfused(at::Tensor& prev_key_cont,
at::Tensor& query_cont,
at::Tensor& attn_mask,
at::Tensor& prev_value_cont,
at::Tensor& output,
int& bsz,
int& seq_len,
int& soft_len,
int& heads,
float& norm_factor,
bool triangular,
bool recompute,
bool local_attention,
int window_size)
{
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
float alpha = norm_factor;
float gemm_beta = 0.0;
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
int k = prev_value_cont.size(2) / heads;
auto mask_stride = get_attn_mask_stride(attn_mask);
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
&alpha,
&gemm_beta,
(T*)prev_key_cont.data_ptr(),
(T*)query_cont.data_ptr(),
(T*)attn_score.data_ptr(),
rocblas_operation_none,
rocblas_operation_none,
soft_len * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
launch_attn_softmax_v2((T*)attn_score.data_ptr(),
(T*)(attn_mask.sizes().size() > 1 ? attn_mask.data_ptr() : nullptr),
(T*)nullptr,
1.0,
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
0,
mask_stride,
1,
Context::Instance().GetCurrentStream(false));
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
&alpha,
&gemm_beta,
(T*)prev_value_cont.data_ptr(),
(T*)attn_score.data_ptr(),
(T*)output.data_ptr(),
rocblas_operation_none,
rocblas_operation_none,
soft_len * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
template <typename T>
std::vector<at::Tensor> ds_softmax_context1(at::Tensor& query,
at::Tensor& prev_key,
at::Tensor& new_key,
at::Tensor& attn_mask,
at::Tensor& prev_value,
at::Tensor& new_value,
int heads,
float norm_factor,
bool merging,
bool triangular,
bool local_attention,
int window_size,
bool no_masking)
{
auto query_cont = query.contiguous();
auto prev_key_cont = prev_key.contiguous();
auto prev_value_cont = prev_value.contiguous();
int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0);
// Attn_Score [ batch Head Sequence-length Softmax-length]
int bsz = query_cont.size(0);
int seq_len = query_cont.size(1);
int soft_len = prev_value.size(1);
auto options = at::TensorOptions()
.dtype(query_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output =
at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options);
attention_unfused<T>(prev_key_cont,
query_cont,
attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()),
prev_value_cont,
output,
bsz,
seq_len,
soft_len,
heads,
norm_factor,
(triangular && (new_size == 0)),
(new_size == 0),
local_attention,
window_size);
return {output, prev_key, prev_value};
}
template <typename T>
void ds_softmax_internal(T* attn_scores,
at::Tensor& attn_mask,
at::Tensor& alibi,
float& layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int bsz,
int seq_len,
int soft_len,
int heads)
{
auto mask_stride = get_attn_mask_stride(attn_mask);
launch_attn_softmax_v2((T*)attn_scores,
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
(alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr),
layer_scale,
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
0,
mask_stride,
1,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
template <typename T>
void attention_unfused(T* prev_key_cont,
T* query_cont,
at::Tensor& attn_mask,
T* prev_value_cont,
T* output,
unsigned& bsz,
int& k,
unsigned& seq_len,
unsigned& soft_len,
int& heads,
float& norm_factor,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
at::Tensor& alibi,
int layer_id)
{
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
float alpha = norm_factor * norm_factor / layer_scale;
float gemm_beta = 0.0;
T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace();
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
seq_len,
k,
&alpha,
&gemm_beta,
(T*)prev_key_cont,
(T*)query_cont,
workspace,
rocblas_operation_transpose,
rocblas_operation_none,
Context::Instance().GetMaxTokenLenght() * k,
seq_len * k,
seq_len * soft_len,
bsz * heads,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
ds_softmax_internal<T>(workspace,
attn_mask,
alibi,
layer_scale,
triangular,
recompute,
local_attention,
window_size,
bsz,
seq_len,
soft_len,
heads);
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
seq_len,
soft_len,
&alpha,
&gemm_beta,
(T*)prev_value_cont,
workspace,
(T*)output,
rocblas_operation_none,
rocblas_operation_none,
Context::Instance().GetMaxTokenLenght() * k,
seq_len * soft_len,
seq_len * k,
bsz * heads,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
void reset_cache() { Context::Instance().reset_tokens(); }
template <typename T>
std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
at::Tensor& attn_mask,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
int heads,
float norm_factor,
bool triangular,
bool local_attention,
int window_size,
bool no_masking,
unsigned layer_id,
unsigned num_layers,
at::Tensor& alibi)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
unsigned hidden_dim = query_key_value.size(2) / 3;
bool is_prompt = (seq_len > 1);
if (is_prompt) Context::Instance().reset_tokens(seq_len);
unsigned soft_len = Context::Instance().current_tokens();
int k = hidden_dim / heads;
auto options = at::TensorOptions()
.dtype(query_key_value.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
T* workspace = (T*)Context::Instance().GetWorkSpace();
size_t buf_size = bsz * seq_len * hidden_dim;
auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options);
auto query_cont = workspace + 8 * buf_size;
size_t offset = 16 * (hidden_dim * bsz * Context::Instance().GetMaxTokenLenght()) +
layer_id * 2 * bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
unsigned all_tokens = soft_len;
auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1);
size_t value_offset = bsz * Context::Instance().GetMaxTokenLenght() * hidden_dim;
T* temp_buf = (T*)output.data_ptr() + at::numel(output);
launch_bias_add_transform_0213<T>((T*)query_cont,
kv_cache,
kv_cache + value_offset,
(T*)query_key_value.data_ptr(),
nullptr,
bsz,
seq_len,
(is_prompt ? 0 : soft_len - 1),
soft_len,
hidden_dim,
heads,
rotary_dim,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
3,
Context::Instance().GetMaxTokenLenght());
if (rotary_dim > 0 && rotate_half)
launch_apply_rotary_pos_emb(query_cont,
kv_cache,
k,
seq_len,
rotary_dim,
(is_prompt ? 0 : soft_len - 1),
heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
attention_unfused<T>(workspace + offset,
(T*)query_cont,
attn_mask,
workspace + offset + value_offset,
temp_buf,
bsz,
k,
seq_len,
all_tokens,
heads,
norm_factor,
(triangular && is_prompt),
is_prompt,
local_attention,
window_size,
alibi,
layer_id);
launch_transform4d_0213<T>((T*)output.data_ptr(),
temp_buf,
bsz,
heads,
seq_len,
output.size(2),
Context::Instance().GetCurrentStream(false),
1);
if (layer_id == num_layers - 1) Context::Instance().advance_tokens();
auto prev_key = torch::from_blob(workspace + offset, {bsz, heads, all_tokens, k}, options);
auto prev_value =
torch::from_blob(workspace + offset + value_offset, {bsz, heads, all_tokens, k}, options);
return {output, prev_key, prev_value};
}
template <typename T>
at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
int intermediate_size = input_cont.size(2);
launch_bias_gelu((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}
at::Tensor ds_bias_geglu(at::Tensor& activation, at::Tensor& bias)
{
/*
Used in FF of Stable diffusion
*/
const int batch_size = activation.size(0);
const int seq_len = activation.size(1);
const int channels = activation.size(2);
const int rows = batch_size * seq_len;
// Dimensionality is cut in half
const int out_channels = channels / 2;
auto output = at::empty({batch_size, seq_len, out_channels}, activation.options());
if (activation.options().dtype() == torch::kFloat32) {
launch_fused_bias_geglu((float*)output.data_ptr(),
(const float*)activation.data_ptr(),
(const float*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
} else {
launch_fused_bias_geglu((__half*)output.data_ptr(),
(const __half*)activation.data_ptr(),
(const __half*)bias.data_ptr(),
rows,
channels,
Context::Instance().GetCurrentStream());
}
return output;
}
template <typename T>
at::Tensor ds_bias_relu(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
int intermediate_size = input_cont.size(2);
launch_bias_relu((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
intermediate_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_bias_add(at::Tensor& input, at::Tensor& bias)
{
auto input_cont = input.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
int hidden_size = input_cont.size(2);
launch_bias_add((T*)input_cont.data_ptr(),
(T*)bias.data_ptr(),
hidden_size,
bsz,
Context::Instance().GetCurrentStream());
return input_cont;
}
template <typename T>
at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias)
{
auto input_cont = input.contiguous();
auto residual_cont = residual.contiguous();
int bsz = input_cont.size(0) * input_cont.size(1);
// launch_bias_residual((T*)input_cont.data_ptr(),
// (T*)residual_cont.data_ptr(),
// (T*)bias.data_ptr(),
// bsz,
// input_cont.size(2),
// (bias.size(0) > 1),
// Context::Instance().GetCurrentStream());
return input_cont;
}
at::Tensor ds_layer_norm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, float epsilon)
{
const int rows = input.size(0) * input.size(1);
const int elems_per_row = input.size(2);
auto output = at::empty_like(input);
if (input.options().dtype() == torch::kFloat16) {
launch_fused_ln((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)gamma.data_ptr(),
(const __half*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
} else {
launch_fused_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
(const float*)gamma.data_ptr(),
(const float*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
}
return output;
}
template <typename T>
void ds_layer_norm_internal(T* workspace,
at::Tensor& input,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
int bsz = input.size(0) * input.size(1);
launch_fused_ln(workspace,
(const T*)input.data_ptr(),
(const T*)gamma.data_ptr(),
(const T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
}
/* Currently only used in unit testing */
at::Tensor ds_layer_norm_residual(at::Tensor& input,
at::Tensor& bias,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
const int rows = input.size(0) * input.size(1);
const int elems_per_row = input.size(2);
auto output = at::empty_like(input);
if (input.options().dtype() == torch::kFloat16) {
launch_fused_residual_ln((__half*)output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)residual.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)gamma.data_ptr(),
(const __half*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln((float*)output.data_ptr(),
(const float*)input.data_ptr(),
(const float*)residual.data_ptr(),
(const float*)bias.data_ptr(),
(const float*)gamma.data_ptr(),
(const float*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
}
return output;
}
/* Currently only used in unit testing */
std::vector<at::Tensor> ds_layer_norm_residual_store_pre_ln_res(at::Tensor& input,
at::Tensor& bias,
at::Tensor& residual,
at::Tensor& gamma,
at::Tensor& beta,
float epsilon)
{
const int rows = input.size(0) * input.size(1);
const int elems_per_row = input.size(2);
auto norm_output = at::empty_like(input);
auto res_output = at::empty_like(input);
if (input.options().dtype() == torch::kFloat16) {
launch_fused_residual_ln_store_pre_ln_res((__half*)norm_output.data_ptr(),
(__half*)res_output.data_ptr(),
(const __half*)input.data_ptr(),
(const __half*)residual.data_ptr(),
(const __half*)bias.data_ptr(),
(const __half*)gamma.data_ptr(),
(const __half*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
} else {
launch_fused_residual_ln_store_pre_ln_res((float*)norm_output.data_ptr(),
(float*)res_output.data_ptr(),
(const float*)input.data_ptr(),
(const float*)residual.data_ptr(),
(const float*)bias.data_ptr(),
(const float*)gamma.data_ptr(),
(const float*)beta.data_ptr(),
epsilon,
rows,
elems_per_row,
Context::Instance().GetCurrentStream());
}
return {norm_output, res_output};
}
template <typename T>
void quantized_gemm(void* output,
T* input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int bsz,
int hidden_size)
{
// T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
auto options = at::TensorOptions()
.dtype(at::kHalf)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto tmp = torch::empty(weight.sizes(), options);
T* weight16 = (T*)tmp.data_ptr();
launch_dequantize(weight16,
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(0),
weight.size(1),
groups,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_transpose,
rocblas_operation_none,
weight.size(0),
bsz,
weight.size(1),
&alpha,
&gemm_beta,
weight16,
(T*)input,
(T*)output,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
template <typename T>
at::Tensor qkv_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias,
bool q_int8)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
workspace += (3 * bsz * input.size(2));
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);
if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
workspace,
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return torch::from_blob(workspace, input.sizes(), input.options());
}
template <typename T>
std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool add_bias,
unsigned num_layers,
bool external_cache,
unsigned mp_size,
unsigned rank,
bool q_int8)
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
int out_size = q_int8 ? weight.size(0) : weight.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
auto inp_norm = qkv_unfused_cublas<T>(
output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8);
return {output, inp_norm};
}
template <typename T>
void quantized_gemm(at::Tensor& output,
at::Tensor& input,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int merge_count)
{
int bsz = input.size(0) * input.size(1);
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto weight16 = at::empty({weight.size(0), weight.size(1)}, options);
launch_dequantize((T*)weight16.data_ptr(),
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
weight.size(0),
weight.size(1),
groups,
merge_count,
Context::Instance().GetCurrentStream());
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_transpose,
rocblas_operation_none,
weight.size(0),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight16.data_ptr(),
(T*)input.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
template <typename T>
at::Tensor ds_qkv_gemm_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool add_bias)
{
int bsz = input.size(0) * input.size(1);
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto inp_norm = ds_layer_norm(input_cont, gamma, beta, epsilon);
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
bool add_bias,
bool do_flash_attn,
int num_heads)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int head_size = input_cont.size(2) / num_heads;
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input_cont.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
if (do_flash_attn) {
if (add_padding) {
int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
auto padded_output = workspace + output.numel();
auto final_output =
padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size);
pad_data(padded_output,
workspace,
3 * bsz * num_heads,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
launch_bias_add_transform_0213<T>(
final_output,
final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size),
final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size),
padded_output,
nullptr,
input.size(0),
input.size(1),
0,
input.size(1),
(num_heads * padded_head_size),
num_heads,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(final_output,
{3, input.size(0), num_heads, input.size(1), padded_head_size},
options);
// return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads,
// padded_head_size}, options);
} else {
auto final_output = workspace + output.numel();
launch_bias_add_transform_0213<T>(
final_output,
final_output + (input.size(0) * input.size(1) * input_cont.size(2)),
final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)),
workspace,
nullptr,
input.size(0),
input.size(1),
0,
input.size(1),
input_cont.size(2),
num_heads,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(
final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options);
// return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads,
// head_size}, options);
}
} else
return output;
}
template <typename T>
std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tensor& value)
{
int head_size = query.size(3);
int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128;
pad_head_seq(workspace,
(T*)query.data_ptr(),
query.size(0) * query.size(1),
query.size(2),
query.size(2),
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
pad_head_seq(key_pad_ptr,
(T*)key.data_ptr(),
query.size(0) * query.size(1),
key.size(2),
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
pad_head_seq(value_pad_ptr,
(T*)value.data_ptr(),
query.size(0) * query.size(1),
key.size(2),
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
return {
at::from_blob(workspace,
{query.size(0), query.size(1), query.size(2), padded_head_size},
query.options()),
at::from_blob(
key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()),
at::from_blob(
value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options())};
}
template <typename T>
std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
at::Tensor& key,
at::Tensor& value,
int heads,
bool add_padding)
{
int head_size = query.size(2) / heads;
int key_value_length = add_padding ? 128 : key.size(1);
int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128))
: head_size;
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length;
launch_pad_add_transform_0213(workspace,
(T*)query.data_ptr(),
query.size(0),
query.size(2),
query.size(1),
query.size(1),
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
launch_pad_add_transform_0213(key_pad_ptr,
(T*)key.data_ptr(),
key.size(0),
key.size(2),
key.size(1),
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
launch_pad_add_transform_0213(value_pad_ptr,
(T*)value.data_ptr(),
value.size(0),
value.size(2),
value.size(1),
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
return {
at::from_blob(
workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()),
at::from_blob(key_pad_ptr,
{query.size(0), heads, key_value_length, padded_head_size},
query.options()),
at::from_blob(value_pad_ptr,
{query.size(0), heads, key_value_length, padded_head_size},
query.options())};
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& q_scale,
int groups)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int bsz = input_cont.size(0) * input_cont.size(1);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
template <typename T>
at::Tensor ds_vector_matmul(at::Tensor& input,
at::Tensor& weight,
bool async_op,
at::Tensor& q_scale,
bool q_int8)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
(T*)input.data_ptr(),
weight,
q_scale,
q_scale.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream(async_op));
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
return output;
}
template <typename T>
at::Tensor ds_vector_matmul_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& q_scale,
int groups,
int merge_count)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, merge_count);
return output;
}
template <typename T>
at::Tensor mlp_unfused_cublas(at::Tensor& output,
at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& weight1,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
ActivationFuncType act_func_type)
{
int bsz = input.size(0) * input.size(1);
T* inp_norm =
(T*)Context::Instance().GetWorkSpace() + torch::numel(input) + torch::numel(output);
T* intermediate = inp_norm + torch::numel(input);
if (mlp_after_attn) {
launch_fused_residual_ln((T*)inp_norm,
(const T*)input.data_ptr(),
(const T*)residual.data_ptr(),
(const T*)input_bias.data_ptr(),
(const T*)gamma.data_ptr(),
(const T*)beta.data_ptr(),
epsilon,
bsz,
input.size(2),
Context::Instance().GetCurrentStream());
} else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
}
if (q_int8) {
quantized_gemm<T>(
intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
inp_norm,
intermediate,
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
if (act_func_type == ActivationFuncType::GELU) {
launch_bias_gelu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
} else if (act_func_type == ActivationFuncType::ReLU) {
launch_bias_relu(intermediate,
(T*)bias.data_ptr(),
q_int8 ? weight.size(0) : weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
}
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
intermediate,
weight1,
q_scale1,
q_scale1.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
rocblas_set_stream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
weight1.size(1),
bsz,
weight1.size(0),
&alpha,
&gemm_beta,
(T*)weight1.data_ptr(),
intermediate,
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
return torch::from_blob(inp_norm, input.sizes(), input.options());
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight_interm,
at::Tensor& weight_out,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
bool preLayerNorm,
bool mlp_after_attn,
at::Tensor& q_scale,
at::Tensor& q_scale1,
bool q_int8,
int activation_type)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
{input.size(0), input.size(1), out_size},
options);
int bsz = input.size(0) * input.size(1);
auto act_func_type = static_cast<ActivationFuncType>(activation_type);
auto res_add = mlp_unfused_cublas<T>(output,
mlp_after_attn ? input : residual,
residual,
input_bias,
weight_interm,
weight_out,
bias,
gamma,
beta,
epsilon,
preLayerNorm,
mlp_after_attn,
q_scale,
q_scale1,
q_int8,
act_func_type);
return {output, res_add};
}
template <typename T>
std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
at::Tensor& residual,
at::Tensor& input_bias,
at::Tensor& weight,
at::Tensor& bias,
at::Tensor& gamma,
at::Tensor& beta,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
auto inp_norm = at::empty_like(input_cont);
auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return {output, residual_add};
}
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& weight_scale,
at::Tensor& bias,
at::Tensor& weight_out,
at::Tensor& weight_out_scale,
const float epsilon,
bool preLayerNorm,
bool q_int8,
bool async_op)
{
auto options = at::TensorOptions()
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int intm_dim = q_int8 ? weight.size(0) : weight.size(1);
// auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
// {input.size(0), input.size(1), out_size},
// options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input);
auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options);
int bsz = input.size(0) * input.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
if (q_int8) {
quantized_gemm<T>(intermediate.data_ptr(),
(T*)input.data_ptr(),
weight,
weight_scale,
weight_scale.size(0),
bsz,
input.size(2));
} else {
rocblas_set_stream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
intm_dim,
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input.data_ptr(),
(T*)intermediate.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
intm_dim,
bsz,
Context::Instance().GetCurrentStream());
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::empty({input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
(T*)intermediate.data_ptr(),
weight_out,
weight_out_scale,
weight_out_scale.size(0),
bsz,
input.size(2));
} else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
rocblas_operation_none,
rocblas_operation_none,
out_size,
bsz,
intm_dim,
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
// hipEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
}
template <typename T>
at::Tensor& residual_add_bias(at::Tensor& hidden_state,
at::Tensor& residual,
const at::Tensor& attention_output,
const at::Tensor& attention_bias,
const at::Tensor& final_bias,
const int mp_size,
const bool mlp_after_attn,
const bool add_bias,
const bool preln)
{
int bsz = residual.size(0) * residual.size(1);
int hidden_size = residual.size(2);
if (mlp_after_attn)
launch_bias_residual(static_cast<T*>(residual.data_ptr()),
static_cast<T*>(hidden_state.data_ptr()),
static_cast<T*>(attention_output.data_ptr()),
static_cast<T*>(final_bias.data_ptr()),
static_cast<T*>(attention_bias.data_ptr()),
bsz,
hidden_size,
mp_size,
preln,
Context::Instance().GetCurrentStream());
else
launch_gptj_residual_add<T>(
static_cast<T*>(residual.data_ptr()),
static_cast<T*>(hidden_state.data_ptr()),
static_cast<T*>(attention_output.data_ptr()),
static_cast<T*>(final_bias.data_ptr()),
static_cast<T*>((add_bias ? attention_bias.data_ptr() : nullptr)),
hidden_size,
bsz,
mp_size,
Context::Instance().GetCurrentStream());
return residual;
}
std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
at::Tensor& key_layer,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
bool rotate_half,
bool rotate_every_two)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
unsigned bsz = mixed_query.size(0);
unsigned head_size = mixed_query.size(2) / num_heads;
unsigned seq_len = mixed_query.size(1);
if (mixed_query.scalar_type() == at::kFloat)
launch_apply_rotary_pos_emb<float>((float*)query_cont.data_ptr(),
(float*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
else
launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(),
(__half*)key_cont.data_ptr(),
head_size,
seq_len,
rotary_dim,
offset,
num_heads,
bsz,
rotate_half,
rotate_every_two,
Context::Instance().GetCurrentStream(),
Context::Instance().GetMaxTokenLenght());
return {query_cont, key_cont};
}
template <typename T>
at::Tensor fused_gemm_gelu_int8(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
const float epsilon,
at::Tensor& q_scale,
int groups,
bool preLayerNorm)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
quantized_gemm<T>(output, input_cont, weight, q_scale, groups, 0);
launch_bias_gelu((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
return output;
}
at::Tensor moe_res_matmul(at::Tensor& moe_res, at::Tensor& coef, at::Tensor& output)
{
int M = moe_res.size(0) * moe_res.size(1);
int N = moe_res.size(2);
Context::Instance().SynchComm();
if (moe_res.scalar_type() == at::kFloat) {
launch_moe_res_matmul<float>((float*)moe_res.data_ptr(),
(float*)coef.data_ptr(),
(float*)output.data_ptr(),
M,
N,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
} else {
launch_moe_res_matmul<__half>((__half*)moe_res.data_ptr(),
(__half*)coef.data_ptr(),
(__half*)output.data_ptr(),
M,
N,
at::hip::getCurrentHIPStreamMasqueradingAsCUDA());
}
return output;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("softmax_fp32", &ds_softmax<float>, "DeepSpeed SoftMax with fp32 (CUDA)");
m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp16 (CUDA)");
m.def(
"softmax_context_fp32", &ds_softmax_context<float>, "DeepSpeed attention with fp32 (CUDA)");
m.def("softmax_context_fp16",
&ds_softmax_context<__half>,
"DeepSpeed attention with fp16 (CUDA)");
m.def("softmax_context_int8",
&ds_softmax_context1<__half>,
"DeepSpeed attention with int8 (CUDA)");
m.def("bias_gelu_fp32", &ds_bias_gelu<float>, "DeepSpeed Gelu with fp32 (CUDA)");
m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_geglu", &ds_bias_geglu, "DeepSpeed Bias GEGLU (CUDA)");
m.def("bias_add_fp32", &ds_bias_add<float>, "DeepSpeed Bias Add with fp32 (CUDA)");
m.def("bias_add_fp16", &ds_bias_add<__half>, "DeepSpeed Gelu with fp16 (CUDA)");
m.def("bias_relu_fp32", &ds_bias_relu<float>, "DeepSpeed ReLU with fp32 (CUDA)");
m.def("bias_relu_fp16", &ds_bias_relu<__half>, "DeepSpeed ReLU with fp16 (CUDA)");
m.def("bias_residual_fp32",
&ds_bias_residual<float>,
"DeepSpeed residual-bias add with fp32 (CUDA)");
m.def("bias_residual_fp16",
&ds_bias_residual<__half>,
"DeepSpeed residual-bias add with fp16 (CUDA)");
m.def("layer_norm", &ds_layer_norm, "DeepSpeed layer norm (CUDA)");
m.def(
"_layer_norm_residual", &ds_layer_norm_residual, "DeepSpeed layer norm + residual (CUDA)");
m.def("layer_norm_residual_store_pre_ln_res",
&ds_layer_norm_residual_store_pre_ln_res,
"DeepSpeed layer norm + store pre Layernorm residual (CUDA)");
m.def("qkv_gemm_fp32", &ds_qkv_gemm<float>, "DeepSpeed qkv gemm with fp32 (CUDA)");
m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)");
m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)");
m.def("mlp_gemm_fp32", &ds_mlp_gemm<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)");
m.def("vector_matmul_fp32", &ds_vector_matmul<float>, "DeepSpeed vector-MM with fp32 (CUDA)");
m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)");
m.def("vector_matmul_int8",
&ds_vector_matmul_int8<__half>,
"DeepSpeed vector-MM with int8 (CUDA)");
m.def("linear_layer_fp32", &ds_linear_layer<float>, "DeepSpeed linear_layer with fp32 (CUDA)");
m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)");
m.def("linear_layer_int8",
&ds_linear_layer_int8<__half>,
"DeepSpeed linear_layer with int8 (CUDA)");
m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu<float>, "DeepSpeed mlp with fp32 (CUDA)");
m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)");
m.def("residual_add_bias_fp32",
&residual_add_bias<float>,
"DeepSpeed residual add with fp32 (CUDA)");
m.def("residual_add_bias_fp16",
&residual_add_bias<__half>,
"DeepSpeed residual add with fp16 (CUDA)");
m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)");
m.def("einsum_sec_sm_ecm_fp32",
&einsum_sec_sm_ecm<float>,
"DeepSpeed vector-MM with fp32 (CUDA)");
m.def("einsum_sec_sm_ecm_fp16",
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
m.def("add_padding_fp32", &add_padding<float>, "DeepSpeed residual add with fp32 (CUDA)");
m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)");
m.def("pad_transform_fp32",
&padd_add_transform<float>,
"DeepSpeed residual add with fp32 (CUDA)");
m.def("pad_transform_fp16",
&padd_add_transform<__half>,
"DeepSpeed residual add with fp16 (CUDA)");
m.def("allocate_workspace_fp32",
&allocate_workspace<float>,
"DeepSpeed memory allocation for GPT inference with fp32 (CUDA)");
m.def("allocate_workspace_fp16",
&allocate_workspace<__half>,
"DeepSpeed memory allocation for GPT inference with fp16 (CUDA)");
m.def("reset_cache", &reset_cache, "Reset Cache for generation tasks");
}
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "inference_cuda_layers.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
#define MAX_CAP 4
#define MAX_SEQ 2048
inline __device__ float relu(const float x) { return x < 0 ? 0 : x; }
/*
In-place relu(biasAdd(x)) for channels last
*/
template <typename T>
__global__ void fused_bias_relu(T* input, const T* bias, int total_count, int intermediate_size)
{
// Input restriction: intermediate_size % vals_per_access == 0
constexpr int granularity = 16;
constexpr int values_per_access = granularity / sizeof(T);
const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;
if (offset < total_count) {
T data[values_per_access];
T data_bias[values_per_access];
mem_access::load_global<granularity>(data, input + offset);
mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
#pragma unroll
for (int i = 0; i < values_per_access; i++) {
float data_f = conversion::to<float>(data[i]);
float bias_f = conversion::to<float>(data_bias[i]);
data[i] = conversion::to<T>(relu(data_f + bias_f));
}
mem_access::store_global<granularity>(input + offset, data);
}
}
template <typename T>
void launch_bias_relu(T* input,
const T* bias,
int intermediate_size,
int batch_size,
hipStream_t stream)
{
constexpr int threads = 1024;
constexpr int granularity = 16;
const int total_count = batch_size * intermediate_size;
const int elems_per_block = threads * (granularity / sizeof(T));
dim3 block_dims(threads);
dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
hipLaunchKernelGGL(( fused_bias_relu), dim3(grid_dims), dim3(block_dims), 0, stream,
input, bias, total_count, intermediate_size);
}
template void launch_bias_relu<float>(float*, const float*, int, int, hipStream_t);
template void launch_bias_relu<__half>(__half*, const __half*, int, int, hipStream_t);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <limits>
#include "inference_cuda_layers.h"
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include <cstdio>
#include <cstdlib>
#include <ctime>
#define ATTN_THREADS 256
#define MAX_REG_SIZE 8
#define minus_infinity -10000.0
void CheckCudaErrorAux(const char* file, unsigned line)
{
hipError_t err = hipGetLastError();
if (err == hipSuccess) return;
std::cerr << hipGetErrorString(err) << "(" << err << ") at " << file << ":" << line
<< std::endl;
throw std::runtime_error("CUDA ERROR!!!\n");
}
#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__)
namespace cg = cooperative_groups;
__global__ void attn_softmax_v2(__half* vals,
__half* mask,
__half* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
int head_offset,
int mask_stride,
int mp_size,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float2 low_data[MAX_REG_SIZE];
float2 high_data[MAX_REG_SIZE];
const __half zero_h = __float2half(0.f);
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
alibi_offset = (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length;
mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
// if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset);
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
low_data[i].x = data_id > window_stride
? __half2float(vals[data_id]) * layer_scale
: minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? __half2float(vals[data_id + 1]) * layer_scale
: minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? __half2float(vals[data_id + 2]) * layer_scale
: minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? __half2float(vals[data_id + 3]) * layer_scale
: minus_infinity;
if (alibi) {
low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]);
low_data[i].y =
low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]);
high_data[i].x =
high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]);
high_data[i].y =
high_data[i].y + __half2float(alibi[data_id + alibi_offset + 3]);
}
if (mask) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else {
low_data[i].x = data_id > window_stride
? __half2float(vals[data_id]) * layer_scale
: minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) &&
(data_id + 1) > window_stride) &&
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1]) * layer_scale
: minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) &&
(data_id + 2) > window_stride) &&
(data_id + 2) < sequence_length)
? __half2float(vals[data_id + 2]) * layer_scale
: minus_infinity;
if (alibi) {
low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y =
low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x =
high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]);
}
high_data[i].y = minus_infinity;
if (mask) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = __float2half(low_data[i].x / sum);
vals[data_id + 1] = __float2half(low_data[i].y / sum);
vals[data_id + 2] = __float2half(high_data[i].x / sum);
vals[data_id + 3] = __float2half(high_data[i].y / sum);
} else {
vals[data_id] = __float2half(low_data[i].x / sum);
if ((data_id + 1) < sequence_length)
vals[data_id + 1] = __float2half(low_data[i].y / sum);
if ((data_id + 2) < sequence_length)
vals[data_id + 2] = __float2half(high_data[i].x / sum);
}
}
}
}
}
__global__ void attn_softmax_v2(float* vals,
float* attn_mask,
float* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int total_count,
int heads,
int sequence_length,
int num_seq,
int head_offset,
int mask_stride,
int mp_size,
int iterations,
int reduceWidth)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
float4 data[MAX_REG_SIZE];
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
int warp_num = blockDim.x >> 5;
int reduce_blocks = reduceWidth >> 5;
int seq_lane = threadIdx.x % reduceWidth;
__shared__ float partialSum[MAX_WARP_NUM];
int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks);
if (iter_offset < total_count) {
vals += (iter_offset * sequence_length);
int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
? (real_seq_id >> 2) - (window_size >> 2)
: 0;
int window_stride =
(local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1;
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i);
if (reduceWidth > WARP_SIZE) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
b.sync();
for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / WARP_SIZE);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
}
}
}
}
template <typename T>
void launch_attn_softmax_v2(T* vals,
T* mask,
T* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
int head_offset,
int mask_stride,
int mp_size,
hipStream_t stream)
{
int total_count = batch_size * heads * num_seq;
int warp_num = ATTN_THREADS / WARP_SIZE;
int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1);
reduce_width = (int)pow(2.0, floor(log2((float)(reduce_width)))) * WARP_SIZE;
dim3 grid_dim((total_count - 1) / (ATTN_THREADS / reduce_width) + 1);
dim3 block_dim(ATTN_THREADS);
const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1;
if (sequence_length <= 32768)
hipLaunchKernelGGL(( attn_softmax_v2), dim3(grid_dim), dim3(block_dim), 0, stream, vals,
mask,
alibi,
layer_scale,
triangular,
recompute,
local_attention,
window_size,
total_count,
heads,
sequence_length,
num_seq,
head_offset,
mask_stride,
mp_size,
iterations,
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!");
}
template void launch_attn_softmax_v2(float* vals,
float* mask,
float* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
int head_offset,
int mask_stride,
int mp_size,
hipStream_t stream);
template void launch_attn_softmax_v2(__half* vals,
__half* mask,
__half* alibi,
float layer_scale,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
int batch_size,
int heads,
int num_seq,
int sequence_length,
int head_offset,
int mask_stride,
int mp_size,
hipStream_t stream);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif
#include "inference_cuda_layers.h"
namespace cg = cooperative_groups;
// Bias add
__global__ void bias_add_transform_0213(float* output,
float* k_cache,
float* v_cache,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
unsigned seq_offset,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens);
int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec =
reinterpret_cast<float4*>(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache));
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
unsigned seq_id = d1 + seq_offset;
float4 inputs = vals_vec[d3];
int lane = d3 & 0x1f;
if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) {
float4 q = vals_vec[d3];
float2* q_f = reinterpret_cast<float2*>(&q);
if (rotate_every_two) {
#pragma unroll
for (int o = 0; o < 2; o++) {
float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_f[o].x * cosf(inv_freq));
q_f[o].y = (q_f[o].x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq));
}
}
output_vec[d3] = q;
} else
output_vec[d3] = inputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
__global__ void bias_add_transform_0213(__half* output, // q
__half* k_cache,
__half* v_cache,
const __half* vals, // qkv
const __half* bias,
int hidden_dim,
int seq_length,
unsigned seq_offset,
int all_tokens,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
{
unsigned half_dim = (rotary_dim << 3) >> 1;
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : max_out_tokens);
int d0_out_stride = hidden_dim * (cnt == 0 ? seq_length : max_out_tokens);
float4 vals_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec =
reinterpret_cast<float4*>(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache));
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
unsigned seq_id = d1 + seq_offset;
int lane = d3 & 0x1f;
if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) {
float4 q = vals_vec[d3];
__half2* q_h = reinterpret_cast<__half2*>(&q);
if (rotate_every_two) {
#pragma unroll
for (int o = 0; o < 4; o++) {
float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
float q_data[2];
q_data[0] = (float)q_h[o].x;
q_data[1] = (float)q_h[o].y;
q_h[o].x = (__half)(-1.0 * q_data[1] * sinf(inv_freq) + q_data[0] * cosf(inv_freq));
q_h[o].y = (__half)(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq));
}
}
output_vec[d3] = q;
} else
output_vec[d3] = vals_vec[d3];
}
// [B S C*H] - > C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
float* k_cache,
float* v_cache,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
unsigned seq_offset,
int all_tokens,
int hidden_dim,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int trans_count,
int max_out_tokens)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213), dim3(grid_dim), dim3(block_dim), 0, stream, output,
k_cache,
v_cache,
vals,
bias,
hidden_dim,
seq_length,
seq_offset,
heads,
rotary_dim >> 2,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
}
template <typename T>
void launch_bias_add_transform_0213(T* outputs,
T* vals,
T* vals1,
const T* vals2,
const T* bias,
int batch_size,
int seq_length,
unsigned seq_offset,
int seq_length1,
int hidden_dim,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int trans_count,
int max_out_tokens);
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
__half* k_cache,
__half* v_cache,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
unsigned seq_offset,
int all_tokens,
int hidden_dim,
int heads,
int rotary_dim,
bool rotate_half,
bool rotate_every_two,
hipStream_t stream,
int trans_count,
int max_out_tokens)
{
hidden_dim >>= 3;
int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213), dim3(grid_dim), dim3(block_dim), 0, stream, output,
k_cache,
v_cache,
vals,
bias,
hidden_dim,
seq_length,
seq_offset,
all_tokens,
heads,
rotary_dim >> 3,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
}
// Bias add
__global__ void pad_add_transform_0213(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size)
{
}
__global__ void pad_add_transform_0213(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size)
{
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y * blockDim.z + threadIdx.z; // Sequence ID (0-127)
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
int d2_out_stride = padded_head_size * padded_seq_len;
int d0_out_stride = heads * d2_out_stride;
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride);
vals_vec += (d1 * d1_stride);
vals_vec += (d2 * d2_stride);
output_vec += (d1 * padded_head_size);
output_vec += (d0 * d0_out_stride);
output_vec += (d2 * d2_out_stride);
if (d3 < d2_stride && d1 < seq_length)
output_vec[d3] = vals_vec[d3];
else
output_vec[d3] = ZERO;
}
template <typename T>
void launch_pad_add_transform_0213(T* output,
const T* vals,
int batch_size,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size,
hipStream_t stream);
// [B S C*H] - > C * [B A S N]
template <>
void launch_pad_add_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size,
hipStream_t stream)
{
}
template <>
void launch_pad_add_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int hidden_dim,
int seq_length,
int padded_seq_len,
int heads,
int padded_head_size,
hipStream_t stream)
{
hidden_dim >>= 3;
dim3 block_dim((padded_head_size >> 3), heads, 2);
dim3 grid_dim(batch_size, padded_seq_len / 2);
hipLaunchKernelGGL(( pad_add_transform_0213), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, hidden_dim, seq_length, padded_seq_len, heads, padded_head_size >> 3);
}
// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
output_vec[d3] = output_arr;
}
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
}
template <typename T>
__global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
}
// 3 * [B A S N] - > [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
hipLaunchKernelGGL(( transform4d_0213<float>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, 1);
}
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
hipLaunchKernelGGL(( transform4d_0213<__half>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, head_ext);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment