Commit eadbbe09 authored by 401qingkong

push rocm deepspeed v0.3.13

parent ab5534fc
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput of
__half2 instructions, and to avoid the conversion overhead (1/8 of __half2 arithmetic).
For specific launch constraints, see the launch functions.
*/
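/*
Per-row math implemented below (as written, the kernels read `residual` as the values to
normalize and write the result into `vals`; epsilon is folded into the variance before rsqrt):
    mu        = mean(residual[row])
    var       = mean((residual[row] - mu)^2) + epsilon
    vals[row] = gamma * (residual[row] - mu) * rsqrt(var) + beta
*/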
#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
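// Tail handling: when row_stride is not an exact multiple of blockDim.x, the thread whose
// high_index still falls inside the row picks up one extra element, and `iterations` is
// bumped so the later loops include it.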
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
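// Block-wide reduction: shfl_down sums within each warp, lane 0 of every warp stores its
// partial into shared memory, the low lanes then re-load the warp partials and reduce them
// with another shfl_down pass, and g.shfl(sum, 0) broadcasts the row total to every lane.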
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
if (training)
if (g.thread_rank() == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
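// Scale the block size with hidden_dim (presumably so hidden_dim / threads, the per-thread
// iteration count, stays within the NORM_REG register buffer); larger widths are rejected.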
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars);
/*
To tune this launch the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
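/*
Illustrative only (not part of this file's API surface): a hedged sketch of how the FP32
launcher below might be invoked from host code, assuming device buffers of the right shapes
are already allocated. Hypothetical values: batch of 8 rows, hidden_dim = 1024.

    launch_bias_residual_layer_norm<float>(vals, residual, gamma, beta,
                                           1e-12f,  // epsilon
                                           8,       // batch_size (rows)
                                           1024,    // hidden_dim
                                           stream,
                                           false,   // preLayerNorm
                                           true,    // training
                                           vars);
*/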
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
// The kernels below have launch-size restrictions; the supported hidden_dim ranges are enumerated explicitly.
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars)
{
int threads = 128;
dim3 grid_dim(batch_size);
// The kernels below have launch-size restrictions; the supported hidden_dim ranges are enumerated explicitly.
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
* Compute gradients using either X_hat or
* the normalized input (invertible).
* Combine transpose with gradients computation.
*/
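/*
Sketch of the reductions below, per output column j (grad = out_grad):
    betta_grad[j] = sum over rows r of grad[r][j]
    gamma_grad[j] = sum over rows r of grad[r][j] * x_hat[r][j]
where x_hat is either recovered from the stored output ((vals_hat - betta) / gamma, the
"invertible" path) or taken to already be the normalized values.
*/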
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
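// TILE_DIM x (TILE_DIM + 1) staging buffers; the extra column avoids shared-memory bank
// conflicts when the accumulators are read back transposed for the warp-level reduction.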
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Normalize Gamma & Betta gradients
* Compute gradients using the input to
* the normalization.
* Combine transpose with gradients computation.
*/
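/*
Same reductions as above, but x_hat is recomputed on the fly from the saved statistics:
    x_hat[r][j]   = (X_data[r][j] - means[r]) * rsqrt(vars[r])
    betta_grad[j] = sum_r grad[r][j],  gamma_grad[j] = sum_r grad[r][j] * x_hat[r][j]
*/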
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input.
* This type of backward is invertible!
* We do the backward using X_hat = (X - u) / sqrt(variance), i.e. the output of the normalization.
*/
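/*
Input-gradient sketch for the kernels below (eps is already folded into vars; with
dxhat = out_grad * gamma and mean(.) taken over the row, the code is equivalent to the
standard LayerNorm backward because mean(x_hat) = 0):
    inp_grad = (dxhat - mean(dxhat) - x_hat * mean(dxhat * x_hat)) * rsqrt(var)
x_hat is recovered from the stored output when `invertible` is set; otherwise vals_hat is
taken to already hold the normalized values.
*/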
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
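// Each FP16 thread operates on packed __half2 pairs, so threads / 2 threads cover the
// hidden_dim / 2 packed elements per row passed to the kernel below.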
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
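/*
Same input-gradient math as the invertible variant above, except x_hat is rebuilt from the
saved statistics: xu = X - mean, x_hat = xu * rsqrt(var), so no gamma/beta inversion is needed.
*/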
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
// float2 result[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
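/*
Fused-add variant: identical input-gradient computation, with the gradient arriving from the
residual branch (out_grad2) added elementwise to inp_grad in the final store.
*/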
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
#include "hip/hip_runtime.h"
#include <math.h>
#include "hip/custom_hip_layers.h"
#include "hip/general_kernels.h"
namespace cg = cooperative_groups;
// Fused attention + softmax
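// Numerically stable softmax over each attention row: the mask is added to the scores,
// the row max is subtracted before exponentiation, and the sum gets a 1e-6 epsilon before
// the final divide (see the two block-wide max/sum reductions below).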
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
const float* attn_mask,
int heads,
int seq_length,
int iterations)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float4* val_cast = reinterpret_cast<float4*>(vals);
const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);
float4 data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float4 mask = attn_mask_cast[mask_offset + data_id];
data[i] = val_cast[data_offset + data_id];
data[i].x += mask.x;
data[i].y += mask.y;
data[i].z += mask.z;
data[i].w += mask.w;
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
data[i].x /= sum;
data[i].y /= sum;
data[i].z /= sum;
data[i].w /= sum;
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
}
}
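// FP16 variant: each float2 load carries four __half values, which are unpacked into float2
// pairs so that the max/sum reductions and the exponentiation run in FP32 for stability.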
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
const __half* attn_mask,
int heads,
int seq_length,
int iterations)
{
#if __CUDA_ARCH__ >= 700
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float2* val_cast = reinterpret_cast<float2*>(vals);
const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);
val_cast += data_offset;
attn_mask_cast += mask_offset;
float2 low_data[MAX_THREAD_ITERATIONS];
float2 high_data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 data = val_cast[data_id];
float2 mask = attn_mask_cast[data_id];
__half2* data_arr = reinterpret_cast<__half2*>(&data);
__half2* mask_arr = reinterpret_cast<__half2*>(&mask);
low_data[i] = __half22float2(data_arr[0]);
high_data[i] = __half22float2(data_arr[1]);
float2 low_mask = __half22float2(mask_arr[0]);
float2 high_mask = __half22float2(mask_arr[1]);
low_data[i].x += low_mask.x;
low_data[i].y += low_mask.y;
high_data[i].x += high_mask.x;
high_data[i].y += high_mask.y;
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
low_data[i].x /= sum;
low_data[i].y /= sum;
high_data[i].x /= sum;
high_data[i].y /= sum;
result_h[0] = __float22half2_rn(low_data[i]);
result_h[1] = __float22half2_rn(high_data[i]);
val_cast[data_id] = result_f;
}
}
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, hipStream_t);
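// Dispatch heuristic: the kernels consume four elements per vectorized load, so they are given
// seq_length4 = sequence_length / 4. The tile size (tbSize) grows with the sequence length,
// starting from 128 threads per block and switching to 256 threads for sequences longer than
// 256. As a hedged example, batch_size=8, heads=12, sequence_length=128 takes the tbSize=32
// path with 128-thread blocks.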
template <>
void launch_attn_softmax<float>(float* vals,
const float* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <>
void launch_attn_softmax<__half>(__half* vals,
const __half* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupport Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
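// Softmax backward: given dL/dy and y = softmax(x), the input gradient is
// dL/dx_i = y_i * (dL/dy_i - sum_j dL/dy_j * y_j), so each row first reduces the dot product
// of the output gradient with the softmax output, then applies it elementwise.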
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32)
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
? (seq_length + iteration_stride - 1) / iteration_stride
: MAX_THREAD_ITERATIONS);
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> 5;
int lane = id & 0x1f;
T val_reg[MAX_THREAD_ITERATIONS];
T soft_reg[MAX_THREAD_ITERATIONS];
float grad_reg = 0.0f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
val_reg[i] = out_grad[row * block_width + data_id];
soft_reg[i] = soft_inp[row * block_width + data_id];
grad_reg += ((float)val_reg[i] *
(float)soft_reg[i]); // the multiplication is accumulated in float;
// doing it in half precision can lose ~2% accuracy
}
}
for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = grad_reg;
b.sync();
if (lane < warp_num) grad_reg = partialSum[lane];
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
grad_reg = g.shfl(grad_reg, id / tbSize);
}
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
out_grad[row * block_width + data_id] = (T)temp;
}
}
}
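// v2: one warp per softmax row. Each thread keeps ITERATIONS elements in registers, the
// row-wise dot product is reduced with warp shuffles, and the gradient is written back in
// place over the incoming grad buffer.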
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
const T* output,
int softmax_length)
{
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
grad += offset;
output += offset;
T grad_reg[ITERATIONS];
T output_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
output_reg[i] = output[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)output_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
}
}
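// The ITERATIONS template argument is chosen so that ITERATIONS * WARP_SIZE covers the
// sequence length (up to 2048); four warps, i.e. four rows, are packed into each thread block.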
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
const T* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream)
{
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (seq_length <= 32)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 1>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 12>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 24>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
const __half* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
const float* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
#define rows_trans 16
#define cols_trans 16
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
__shared__ T data_block[rows_trans * (cols_trans + 1)];
int r = threadIdx.x / cols_trans;
int c = threadIdx.x % cols_trans;
int m = row_width / cols_trans;
int i = blockIdx.x / m * rows_trans + r;
int j = blockIdx.x % m * cols_trans + c;
int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);
for (int k = 0; k < rows_trans; k += row_stride)
data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];
__syncthreads();
i = blockIdx.x % m * rows_trans + r;
j = blockIdx.x / m * cols_trans + c;
for (int k = 0; k < rows_trans; k += row_stride)
out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
template <>
void Transpose<__half>(const __half* inp_mat,
__half* out_mat,
int rows,
int cols,
hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<__half>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<float>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
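// transform_0213: permutes [batch, seq, heads, head_dim] into [batch, heads, seq, head_dim]
// (the 0213 axis order), copying vectorized float4 chunks; head_ext splits the heads across
// extra blocks when a single block cannot cover the whole hidden dimension.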
template <typename T>
__global__ void transform_0213(T* output,
const T* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void transform_0213<float>(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}
template <>
__global__ void transform_0213<__half>(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];
#endif
}
template <>
void launch_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<float>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 3;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<__half>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
// Bias add
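// bias_add_transform_0213 fuses the QKV bias addition with the 0213 transform above: the
// input holds trans_count (typically 3, for Q/K/V) concatenated [batch, seq, hidden] matrices,
// and each one is written out as a separate [batch, heads, seq, head_dim] tensor.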
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
#if defined(__ACC_HALF__)
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
#else
float2 bias_arr_f[4];
float2 vals_arr_f[4];
#pragma unroll
for (int l = 0; l < 4; l++) {
bias_arr_f[l] = __half22float2(bias_half[l]);
vals_arr_f[l] = __half22float2(vals_half[l]);
vals_arr_f[l].x += bias_arr_f[l].x;
vals_arr_f[l].y += bias_arr_f[l].y;
output_half[l] = __float22half2_rn(vals_arr_f[l]);
}
#endif
output_vec[d3] = output_arr;
#endif
}
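// v2 (FP16 only): stages the bias-added QKV values for two sequence positions in shared memory
// and rewrites them in a coalesced order; the launcher below selects it when the vectorized
// hidden dimension falls between 16 and 128.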
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
#endif
}
// [B S C*H] - > C * [B A S N]
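// B = batch, S = sequence length, C = trans_count (e.g. 3 for Q/K/V), H = hidden size,
// A = attention heads, N = H / A (per-head size).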
template <>
void launch_bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<float>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<__half>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
} else {
dim3 block_dim(hidden_dim / heads, heads, trans_count);
dim3 grid_dim(batch_size, seq_length / 2);
hipLaunchKernelGGL(( bias_add_transform_0213_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads);
}
}
template <typename T>
__global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
#endif
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
#endif
}
// 3 * [B A S N] - > [B S C*H]
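// Inverse of launch_bias_add_transform_0213: gathers the per-head [B A S N] matrices back
// into a single contiguous [B S C*H] layout.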
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
hipLaunchKernelGGL(( transform4d_0213<float>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, 1);
}
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
hipLaunchKernelGGL(( transform4d_0213<__half>), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim, head_ext);
} else {
dim3 grid_dims(batch_size, seq_length / 2);
dim3 block_dims(hidden_dim / heads, heads, trans_count);
hipLaunchKernelGGL(( transform4d_0213_v2), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim);
}
}
/*
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/extension.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
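// Thin wrappers around PyTorch's dense-tensor flatten/unflatten utilities, exposed to Python
// for packing many tensors into one contiguous buffer (e.g. for communication).
// Hypothetical usage once built as a torch extension:
//   flat = util.flatten([t1, t2]); t1_new, t2_new = util.unflatten(flat, [t1, t2])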
at::Tensor flatten(std::vector<at::Tensor> tensors)
{
return torch::utils::flatten_dense_tensors(tensors);
}
std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
return torch::utils::unflatten_dense_tensors(flat, tensors);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("flatten", &flatten, "Flatten dense tensors");
m.def("unflatten", &unflatten, "Unflatten dense tensors");
}
../../csrc
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include "cpu_adam.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
// C++ interface
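/*
Scalar/SIMD Adam(W) step over a flat parameter buffer. Per element, with g the gradient,
m the exp_avg and v the exp_avg_sq:
    m = betta1 * m + (1 - betta1) * g
    v = betta2 * v + (1 - betta2) * g^2
    p = p + step_size * m / (sqrt(v) * _bias_correction2 + eps)
where step_size = -alpha / _bias_correction1; _bias_correction1 and _bias_correction2 are the
bias-correction terms maintained elsewhere in the optimizer (this sketch assumes the usual
1 - beta^t style corrections). Classic Adam folds the weight decay into the gradient, while
AdamW mode applies a decoupled decay directly to the parameter. When dev_params is given, the
updated parameters are staged in a double buffer and copied asynchronously to the device in
TILE-sized chunks on two alternating streams.
*/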
void Adam_Optimizer::Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
grad_4.data = SIMD_LOAD(grads + i);
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
AVX_Data param_4;
param_4.data = SIMD_LOAD(_params + i);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
}
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
if (_weight_decay > 0 && _adamw_mode) {
param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data);
}
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
SIMD_STORE(_params + i, param_4.data);
if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
}
}
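// Step_4: same update as Step, manually unrolled over four SIMD vectors per loop iteration;
// any tail that does not fill a four-vector group falls back to Step above.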
void Adam_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
AVX_Data grad_4[4];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
AVX_Data momentum_4[4];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
AVX_Data variance_4[4];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
AVX_Data param_4[4];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
}
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int create_adam_optimizer(int optimizer_id,
float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0,
bool adamw_mode = true)
{
auto opt =
std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);
s_optimizers[optimizer_id] = opt;
#if defined(__AVX512__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX512 arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#else
#if defined(__AVX256__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX2 arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#else
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with scalar arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#endif
#endif
return 0;
}
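// Step_8: eight-wide SIMD unrolling of the same update; any remainder is delegated to Step_4.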
void Adam_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
AVX_Data grad_4[8];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2));
grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5);
grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6);
grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7);
AVX_Data momentum_4[8];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2));
momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5);
momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6);
momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7);
AVX_Data variance_4[8];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2));
variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5);
variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6);
variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7);
AVX_Data param_4[8];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2));
param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5);
param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data);
grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data);
grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data);
grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data);
momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data);
momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data);
momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data);
momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data);
momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data);
momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data);
momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data);
variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data);
variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data);
variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data);
variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data);
variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data);
variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[4].data = SIMD_SQRT(variance_4[4].data);
grad_4[5].data = SIMD_SQRT(variance_4[5].data);
grad_4[6].data = SIMD_SQRT(variance_4[6].data);
grad_4[7].data = SIMD_SQRT(variance_4[7].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[4].data = SIMD_FMA(grad_4[4].data, bias2_sqrt.data, eps_4.data);
grad_4[5].data = SIMD_FMA(grad_4[5].data, bias2_sqrt.data, eps_4.data);
grad_4[6].data = SIMD_FMA(grad_4[6].data, bias2_sqrt.data, eps_4.data);
grad_4[7].data = SIMD_FMA(grad_4[7].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_DIV(momentum_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_DIV(momentum_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, param_4[7].data);
}
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(grad_4[4].data, step_size_4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(grad_4[5].data, step_size_4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(grad_4[6].data, step_size_4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(grad_4[7].data, step_size_4.data, param_4[7].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 2), param_4[4].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 7, param_4[7].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2),
param_4[4].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 2), momentum_4[4].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 5, momentum_4[5].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 6, momentum_4[6].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 7, momentum_4[7].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 2), variance_4[4].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 5, variance_4[5].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 6, variance_4[6].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int ds_adam_step(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq)
{
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
opt->SynchronizeStreams();
return 0;
}
int ds_adam_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
__half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
opt->SynchronizeStreams();
return 0;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
}
#include "custom_cuda_layers.h"
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < size) { output[id] = (__half)input[id]; }
}
void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
int threads = 1024;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);
param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
#include <torch/extension.h>
void multi_tensor_adam_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("multi_tensor_adam",
&multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
}
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
#include "cpu_adam.h"
#include <hip/hip_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hiprand.h"
#include "custom_cuda_layers.h"
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
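// Added illustration (not part of the original source): ROUND_DOWN clears the low bits of
// `size`, so it rounds down to a multiple of `step` and assumes `step` is a power of two
// (SIMD_WIDTH and its shifted multiples satisfy this).
static_assert(ROUND_DOWN(1000, 16) == 992, "rounds down to the previous multiple of 16");
static_assert(ROUND_DOWN(1024, 16) == 1024, "exact multiples are left unchanged");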
// C++ interface
void Adam_Optimizer::Step(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
float step_size = -1 * _alpha / _bias_correction1;
float w_decay = -1 * _alpha * _weight_decay;
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
// Wait for the earlier asynchronous param_update launched from this half of the
// double buffer before overwriting it again.
if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += SIMD_WIDTH) {
AVX_Data grad_4;
grad_4.data = SIMD_LOAD(grads + i);
AVX_Data momentum_4;
momentum_4.data = SIMD_LOAD(_exp_avg + i);
AVX_Data variance_4;
variance_4.data = SIMD_LOAD(_exp_avg_sq + i);
AVX_Data param_4;
param_4.data = SIMD_LOAD(_params + i);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
}
momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);
variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);
grad_4.data = SIMD_SQRT(variance_4.data);
grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
if (_weight_decay > 0 && _adamw_mode) {
param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data);
}
param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);
SIMD_STORE(_params + i, param_4.data);
if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);
SIMD_STORE(_exp_avg + i, momentum_4.data);
SIMD_STORE(_exp_avg_sq + i, variance_4.data);
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t k = t; k < offset; k++) {
float grad = grads[k];
float param = _params[k];
float momentum = _exp_avg[k];
float variance = _exp_avg_sq[k];
if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
momentum = momentum * _betta1;
momentum = grad * betta1_minus1 + momentum;
variance = variance * _betta2;
grad = grad * grad;
variance = grad * betta2_minus1 + variance;
grad = sqrt(variance);
grad = grad * _bias_correction2 + _eps;
grad = momentum / grad;
if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
param = grad * step_size + param;
if (dev_params) _doubled_buffer[_buf_index][k - t] = param;
_params[k] = param;
_exp_avg[k] = momentum;
_exp_avg_sq[k] = variance;
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
}
}
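/* Added note (not part of the original source): the scalar tail above and the SIMD paths
   implement the same per-element update. Assuming, as the member names suggest, that
   _bias_correction1 = 1 - beta1^t and _bias_correction2 = 1 / sqrt(1 - beta2^t), each
   element follows the standard Adam/AdamW recurrence:
       m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
       v_t   = beta2 * v_{t-1} + (1 - beta2) * g_t^2
       denom = sqrt(v_t) * _bias_correction2 + eps          (= sqrt(v_hat_t) + eps)
       p_t   = p_{t-1} + step_size * m_t / denom            (step_size = -alpha / _bias_correction1)
   with the gradient pre-scaled by L2 decay (g += weight_decay * p) when !_adamw_mode, or the
   parameter decayed directly (p += -alpha * weight_decay * p) when _adamw_mode. */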
void Adam_Optimizer::Step_4(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
AVX_Data grad_4[4];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
AVX_Data momentum_4[4];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
AVX_Data variance_4[4];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
AVX_Data param_4[4];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
}
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int create_adam_optimizer(int optimizer_id,
float alpha = 1e-3,
float betta1 = 0.9,
float betta2 = 0.999,
float eps = 1e-8,
float weight_decay = 0,
bool adamw_mode = true)
{
auto opt =
std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);
s_optimizers[optimizer_id] = opt;
#if defined(__AVX512__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX512 arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#else
#if defined(__AVX256__)
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with AVX2 arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#else
std::cout << "Adam Optimizer #" << optimizer_id
<< " is created with scalar arithmetic capability." << std::endl;
printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
alpha,
betta1,
betta2,
weight_decay,
(int)adamw_mode);
#endif
#endif
return 0;
}
void Adam_Optimizer::Step_8(float* _params,
float* grads,
float* _exp_avg,
float* _exp_avg_sq,
size_t _param_size,
__half* dev_params)
{
size_t rounded_size = 0;
#if defined(__AVX512__) or defined(__AVX256__)
AVX_Data betta1_4;
betta1_4.data = SIMD_SET(_betta1);
AVX_Data betta2_4;
betta2_4.data = SIMD_SET(_betta2);
float betta1_minus1 = 1 - _betta1;
float betta2_minus1 = 1 - _betta2;
AVX_Data betta1_minus1_4;
betta1_minus1_4.data = SIMD_SET(betta1_minus1);
AVX_Data betta2_minus1_4;
betta2_minus1_4.data = SIMD_SET(betta2_minus1);
AVX_Data bias2_sqrt;
bias2_sqrt.data = SIMD_SET(_bias_correction2);
AVX_Data eps_4;
eps_4.data = SIMD_SET(_eps);
float step_size = -1 * _alpha / _bias_correction1;
AVX_Data step_size_4;
step_size_4.data = SIMD_SET(step_size);
float w_decay = -1 * _alpha * _weight_decay;
AVX_Data weight_decay4;
if (_weight_decay > 0)
weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
size_t offset = copy_size + t;
if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
AVX_Data grad_4[8];
grad_4[0].data = SIMD_LOAD(grads + i);
grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2));
grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5);
grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6);
grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7);
AVX_Data momentum_4[8];
momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2));
momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5);
momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6);
momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7);
AVX_Data variance_4[8];
variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2));
variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5);
variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6);
variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7);
AVX_Data param_4[8];
param_4[0].data = SIMD_LOAD(_params + i);
param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2));
param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5);
param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);
if (_weight_decay > 0 && !_adamw_mode) {
grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data);
grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data);
grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data);
grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data);
}
momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data);
momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data);
momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data);
momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data);
momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data);
momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data);
momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data);
momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data);
variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data);
variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data);
variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data);
variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data);
grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data);
variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data);
variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data);
variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data);
variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data);
grad_4[0].data = SIMD_SQRT(variance_4[0].data);
grad_4[1].data = SIMD_SQRT(variance_4[1].data);
grad_4[2].data = SIMD_SQRT(variance_4[2].data);
grad_4[3].data = SIMD_SQRT(variance_4[3].data);
grad_4[4].data = SIMD_SQRT(variance_4[4].data);
grad_4[5].data = SIMD_SQRT(variance_4[5].data);
grad_4[6].data = SIMD_SQRT(variance_4[6].data);
grad_4[7].data = SIMD_SQRT(variance_4[7].data);
grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);
grad_4[4].data = SIMD_FMA(grad_4[4].data, bias2_sqrt.data, eps_4.data);
grad_4[5].data = SIMD_FMA(grad_4[5].data, bias2_sqrt.data, eps_4.data);
grad_4[6].data = SIMD_FMA(grad_4[6].data, bias2_sqrt.data, eps_4.data);
grad_4[7].data = SIMD_FMA(grad_4[7].data, bias2_sqrt.data, eps_4.data);
grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);
grad_4[4].data = SIMD_DIV(momentum_4[4].data, grad_4[4].data);
grad_4[5].data = SIMD_DIV(momentum_4[5].data, grad_4[5].data);
grad_4[6].data = SIMD_DIV(momentum_4[6].data, grad_4[6].data);
grad_4[7].data = SIMD_DIV(momentum_4[7].data, grad_4[7].data);
if (_weight_decay > 0 && _adamw_mode) {
param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, param_4[7].data);
}
param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);
param_4[4].data = SIMD_FMA(grad_4[4].data, step_size_4.data, param_4[4].data);
param_4[5].data = SIMD_FMA(grad_4[5].data, step_size_4.data, param_4[5].data);
param_4[6].data = SIMD_FMA(grad_4[6].data, step_size_4.data, param_4[6].data);
param_4[7].data = SIMD_FMA(grad_4[7].data, step_size_4.data, param_4[7].data);
SIMD_STORE(_params + i, param_4[0].data);
SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_params + i + (SIMD_WIDTH << 2), param_4[4].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_params + i + SIMD_WIDTH * 7, param_4[7].data);
if (dev_params) {
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
param_4[2].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 2),
param_4[4].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 5, param_4[5].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 6, param_4[6].data);
SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 7, param_4[7].data);
}
SIMD_STORE(_exp_avg + i, momentum_4[0].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);
SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 2), momentum_4[4].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 5, momentum_4[5].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 6, momentum_4[6].data);
SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 7, momentum_4[7].data);
SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 2), variance_4[4].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 5, variance_4[5].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 6, variance_4[6].data);
SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 7, variance_4[7].data);
}
if (dev_params) {
launch_param_update(
_doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
_buf_index = !_buf_index;
}
}
#endif
if (_param_size > rounded_size)
Step_4((_params + rounded_size),
(grads + rounded_size),
(_exp_avg + rounded_size),
(_exp_avg_sq + rounded_size),
(_param_size - rounded_size),
(dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int ds_adam_step(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq)
{
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
opt->SynchronizeStreams();
return 0;
}
int ds_adam_step_plus_copy(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
torch::Tensor& gpu_params)
{
auto params_c = params.contiguous();
auto gpu_params_c = gpu_params.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
auto grads_c = grads.contiguous();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
__half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(
params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
opt->SynchronizeStreams();
return 0;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
m.def("adam_update_copy",
&ds_adam_step_plus_copy,
"DeepSpeed CPU Adam update and param copy (C++)");
m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
}
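/* Illustrative only (added sketch, not part of the original source): a minimal host-side
   driver showing how the entry points above compose for one optimizer step. Tensor shapes
   and hyper-parameters are arbitrary example values. */
inline void example_cpu_adam_driver()
{
    const int opt_id = 0;
    create_adam_optimizer(opt_id, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f, /*adamw_mode=*/true);
    auto params = torch::zeros({1024}, torch::kFloat32);
    auto grads = torch::ones({1024}, torch::kFloat32);
    auto exp_avg = torch::zeros({1024}, torch::kFloat32);
    auto exp_avg_sq = torch::zeros({1024}, torch::kFloat32);
    // optimizer_id, step, lr, beta1, beta2, eps, weight_decay, bias_correction, tensors...
    ds_adam_step(opt_id, 1, 1e-3f, 0.9f, 0.999f, 1e-8f, 0.0f, true, params, grads, exp_avg, exp_avg_sq);
}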
#include "hip/hip_runtime.h"
#include "custom_cuda_layers.h"
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < size) { output[id] = (__half)input[id]; }
}
void launch_param_update(const float* input, __half* output, int size, hipStream_t stream)
{
int threads = 1024;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);
hipLaunchKernelGGL(( param_update_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, input, output, size);
}
#include <torch/extension.h>
void multi_tensor_adam_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("multi_tensor_adam",
&multi_tensor_adam_cuda,
"Compute and apply gradient update to parameters for Adam optimizer");
}
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode (AdamW)
} adamMode_t;
using MATH_T = float;
template <typename T>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float lr,
adamMode_t mode,
const float decay)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially used to pass in a list of scalars
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_adam_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay)
{
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - ::pow(beta1, step);
bias_correction2 = 1 - ::pow(beta2, step);
}
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
AT_CUDA_CHECK(hipGetLastError());
}
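/* Illustrative only (added sketch, not part of the original source): AdamFunctor reads
   addresses[0..3] as gradients, params, exp_avg and exp_avg_sq, so callers are expected to
   pass tensor_lists in that order. chunk_size and the hyper-parameters are example values. */
inline void example_multi_tensor_adam(std::vector<at::Tensor>& grads,
                                      std::vector<at::Tensor>& params,
                                      std::vector<at::Tensor>& exp_avg,
                                      std::vector<at::Tensor>& exp_avg_sq,
                                      at::Tensor noop_flag)
{
    std::vector<std::vector<at::Tensor>> tensor_lists = {grads, params, exp_avg, exp_avg_sq};
    multi_tensor_adam_cuda(/*chunk_size=*/2048 * 32,
                           noop_flag,
                           tensor_lists,
                           /*lr=*/1e-3f,
                           /*beta1=*/0.9f,
                           /*beta2=*/0.999f,
                           /*epsilon=*/1e-8f,
                           /*step=*/1,
                           /*mode=*/ADAM_MODE_1,
                           /*bias_correction=*/1,
                           /*weight_decay=*/0.0f);
}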
#include "hip/hip_runtime.h"
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include "compat.h"
#include <assert.h>
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (e.g., Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
void* addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
int start_tensor_this_launch;
};
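// Added example (not part of the original source): each launched block handles one chunk of
// one tensor. With chunk_size = 65536, a tensor of 200,000 elements is split into
// ceil(200000 / 65536) = 4 chunks and therefore contributes 4 entries to
// block_to_tensor/block_to_chunk; addresses[d][t] holds the d-th list's data pointer for
// tensor slot t.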
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
volatile int* noop_flag,
T tl,
U callable,
ArgTypes... args)
{
// Hand the chunk information to the user-supplied functor to process however it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
int chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(tensor_lists[0][0]));
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
hipLaunchKernelGGL(( multi_tensor_apply_kernel), dim3(loc_block_info), dim3(block_size), 0, stream,
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
AT_CUDA_CHECK(hipGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if (chunk == chunks_this_tensor - 1) {
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
// std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
} else {
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
// std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
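/* Illustrative only (added sketch, not part of the original source): the minimal contract
   multi_tensor_apply expects from its callable. Each block receives the chunk metadata and
   processes at most chunk_size elements of one tensor, mirroring AdamFunctor but simply
   scaling float tensors in place. */
struct ExampleScaleFunctor {
    __device__ __forceinline__ void operator()(int chunk_size,
                                               volatile int* /*noop_gmem*/,
                                               TensorListMetadata<1>& tl,
                                               float scale)
    {
        int tensor_loc = tl.block_to_tensor[blockIdx.x];
        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc] - chunk_idx * chunk_size;
        float* x = (float*)tl.addresses[0][tensor_loc] + chunk_idx * chunk_size;
        for (int i = threadIdx.x; i < n && i < chunk_size; i += blockDim.x) { x[i] *= scale; }
    }
};
// Host-side sketch (hypothetical list of contiguous float CUDA tensors):
//   multi_tensor_apply<1>(512, 2048 * 32, noop_flag, {float_tensors}, ExampleScaleFunctor(), 0.5f);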
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode (AdamW)
} adamMode_t;
using MATH_T = float;
template <typename T>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float lr,
adamMode_t mode,
const float decay)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially used to pass in a list of scalars
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_adam_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay)
{
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - std::pow(beta1, step);
bias_correction2 = 1 - std::pow(beta2, step);
}
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
AT_CUDA_CHECK(cudaGetLastError());
}
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"
#include <assert.h>
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (e.g., Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
template <int n>
struct TensorListMetadata {
void* addresses[n][depth_to_max_tensors[n - 1]];
int sizes[depth_to_max_tensors[n - 1]];
unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
int block_to_chunk[depth_to_max_blocks[n - 1]]; // I fear this needs to be a full int.
int start_tensor_this_launch;
};
template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
volatile int* noop_flag,
T tl,
U callable,
ArgTypes... args)
{
// Hand the chunk information to the user-supplied functor to process however it likes.
callable(chunk_size, noop_flag, tl, args...);
}
template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
int chunk_size,
const at::Tensor& noop_flag,
const std::vector<std::vector<at::Tensor>>& tensor_lists,
T callable,
ArgTypes... args)
{
TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
int len0 = tensor_lists[0].size();
TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
auto ref_device = tensor_lists[0][0].device();
TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
{
TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
for (int t = 0; t < tensor_lists[l].size(); t++) {
// TODO: Print which tensor fails.
bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
contiguous_memory = (contiguous_memory ||
tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
"A tensor was not on the same device as the first tensor");
TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
}
}
int ntensors = tensor_lists[0].size();
TensorListMetadata<depth> tl;
const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
auto stream = at::cuda::getCurrentCUDAStream();
tl.start_tensor_this_launch = 0;
int loc_block_info = 0;
int loc_tensor_info = 0;
for (int t = 0; t < ntensors; t++) {
tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
for (int d = 0; d < depth; d++)
tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
loc_tensor_info++;
int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;
for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
// std::cout << chunks_this_tensor << std::endl;
tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
tl.block_to_chunk[loc_block_info] = chunk;
loc_block_info++;
bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
chunk == chunks_this_tensor - 1);
bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
if (tensors_full || blocks_full || last_chunk) {
// using accscalar_t = acc_type<scalar_t, true>;
multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);
AT_CUDA_CHECK(cudaGetLastError());
// Reset. The control flow possibilities here make my brain hurt.
loc_block_info = 0;
if (chunk == chunks_this_tensor - 1) {
// std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
// std::endl;
loc_tensor_info = 0;
tl.start_tensor_this_launch = t + 1;
} else {
// std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
// std::endl;
tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
for (int d = 0; d < depth; d++)
tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
loc_tensor_info = 1;
tl.start_tensor_this_launch = t;
}
}
}
}
}
#pragma once
#ifdef _WIN32
#include <windows.h>
#else
#include <time.h>
#endif
#ifdef _WIN32
class Stopwatch {
private:
double m_total_time;
LARGE_INTEGER m_start_time;
public:
Stopwatch() { m_total_time = 0.0; }
~Stopwatch() {}
void Reset() { m_total_time = 0.0; }
void Start() { QueryPerformanceCounter(&m_start_time); }
void Restart()
{
m_total_time = 0.0;
QueryPerformanceCounter(&m_start_time);
}
void Stop()
{
LARGE_INTEGER frequency;
LARGE_INTEGER stop_time;
QueryPerformanceFrequency(&frequency);
QueryPerformanceCounter(&stop_time);
m_total_time +=
((double)(stop_time.QuadPart - m_start_time.QuadPart) / (double)frequency.QuadPart);
}
double GetTimeInSeconds() { return m_total_time; }
};
#else
class Stopwatch {
private:
double m_total_time;
struct timespec m_start_time;
bool m_is_started;
public:
Stopwatch()
{
m_total_time = 0.0;
m_is_started = false;
}
~Stopwatch() {}
void Reset() { m_total_time = 0.0; }
void Start()
{
clock_gettime(CLOCK_MONOTONIC, &m_start_time);
m_is_started = true;
}
void Restart()
{
m_total_time = 0.0;
clock_gettime(CLOCK_MONOTONIC, &m_start_time);
m_is_started = true;
}
void Stop()
{
if (m_is_started) {
m_is_started = false;
struct timespec end_time;
clock_gettime(CLOCK_MONOTONIC, &end_time);
m_total_time += (double)(end_time.tv_sec - m_start_time.tv_sec) +
(double)(end_time.tv_nsec - m_start_time.tv_nsec) / 1e9;
}
}
double GetTimeInSeconds()
{
if (m_is_started) {
Stop();
Start();
}
return m_total_time;
}
};
#endif
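/* Illustrative only (added sketch, not part of the original source): timing a host-side
   region with the Stopwatch defined above. */
inline double example_stopwatch_usage()
{
    Stopwatch timer;
    timer.Start();
    volatile double acc = 0.0;
    for (int i = 0; i < (1 << 20); ++i) acc += i * 0.5;  // the work being measured
    timer.Stop();
    return timer.GetTimeInSeconds();  // accumulated seconds across Start()/Stop() pairs
}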
#ifndef __TIMER_H__
#define __TIMER_H__
#include <cuda_runtime.h>
#include <chrono>
#include "cuda.h"
class GPUTimer {
cudaEvent_t start, stop;
public:
GPUTimer()
{
cudaEventCreate(&start);
cudaEventCreate(&stop);
}
~GPUTimer()
{
cudaEventDestroy(start);
cudaEventDestroy(stop);
}
inline void Record() { cudaEventRecord(start); }
inline void Elapsed(float& time_elapsed)
{
cudaEventRecord(stop);
cudaEventSynchronize(stop);
cudaEventElapsedTime(&time_elapsed, start, stop);
}
};
class CPUTimer {
std::chrono::high_resolution_clock::time_point start;
public:
CPUTimer() : start(std::chrono::high_resolution_clock::now()) {}
inline void Reset() { start = std::chrono::high_resolution_clock::now(); }
inline float Elapsed()
{
auto temp = start;
start = std::chrono::high_resolution_clock::now();
return (float)(std::chrono::duration_cast<std::chrono::microseconds>(start - temp).count() /
1e3);
}
};
#endif
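/* Illustrative only (added sketch, not part of the original source): GPUTimer brackets
   device work with CUDA events and reports milliseconds; CPUTimer::Elapsed() reports the
   milliseconds since the last Reset() or construction. */
inline void example_timer_usage()
{
    GPUTimer gpu_timer;
    gpu_timer.Record();         // record the start event on the default stream
    // ... launch the kernels to be timed here ...
    float gpu_ms = 0.0f;
    gpu_timer.Elapsed(gpu_ms);  // records the stop event, synchronizes, fills elapsed ms
    CPUTimer cpu_timer;
    // ... host-side work to be timed ...
    float cpu_ms = cpu_timer.Elapsed();
    (void)gpu_ms;
    (void)cpu_ms;
}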
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "gemm_test.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
return (std::max)(
(std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
// Use at least 1 block, since CUDA does not allow launching an empty grid
1);
}
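// Added example (not part of the original source): with DS_CUDA_NUM_THREADS = 512,
// DS_GET_BLOCKS(100000) = min(ceil(100000 / 512), 262144) = 196 blocks, and
// DS_GET_BLOCKS(0) still returns 1 because an empty grid cannot be launched.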
class Context {
public:
Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, 123);
if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
auto message = std::string("Failed to create cublas handle.");
std::cerr << message << std::endl;
throw std::runtime_error(message);
}
}
virtual ~Context()
{
cublasDestroy(_cublasHandle);
cudaFree(_workspace);
}
static Context& Instance()
{
static Context _ctx;
return _ctx;
}
void SetWorkSpace(void* workspace)
{
if (!workspace) { throw std::runtime_error("Workspace is null."); }
_workspace = workspace;
}
void* GetWorkSpace() { return _workspace; }
curandGenerator_t& GetRandGenerator() { return _gen; }
cudaStream_t GetCurrentStream()
{
// get current pytorch stream.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
return stream;
}
cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }
cublasHandle_t GetCublasHandle() { return _cublasHandle; }
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}
void SetSeed(uint64_t new_seed) { _seed = new_seed; }
void TestGemmFP16(bool test_gemm, int batch_size, int seq_len, int head_num, int size_per_head)
{
// avoid rerun.
if (_gemm_algos.size() > 0) return;
if (test_gemm) {
cublasHandle_t handle = GetCublasHandle();
std::unique_ptr<GemmTest<__half>> test_qkv_fw(
new GemmTest<__half>(batch_size * seq_len, // M
head_num * size_per_head, // N
head_num * size_per_head, // K
CUBLAS_OP_T,
CUBLAS_OP_N,
handle));
std::unique_ptr<GemmTest<__half>> test_inter(
new GemmTest<__half>(batch_size * seq_len, // M
4 * head_num * size_per_head, // N
head_num * size_per_head, // K
CUBLAS_OP_T,
CUBLAS_OP_N,
handle));
std::unique_ptr<GemmTest<__half>> test_output(
new GemmTest<__half>(batch_size * seq_len, // M
head_num * size_per_head, // N
4 * head_num * size_per_head, // K
CUBLAS_OP_T,
CUBLAS_OP_N,
handle));
std::unique_ptr<StridedGemmTest<__half>> test_attn_scores(
new StridedGemmTest<__half>(batch_size * head_num, // batch
seq_len, // M
seq_len, // N
size_per_head, // K
CUBLAS_OP_T,
CUBLAS_OP_N,
handle));
std::unique_ptr<StridedGemmTest<__half>> test_attn_context(
new StridedGemmTest<__half>(batch_size * head_num, // batch
size_per_head, // M
seq_len, // N
seq_len, // K
CUBLAS_OP_N,
CUBLAS_OP_N,
handle));
_gemm_algos.push_back(test_qkv_fw->TestAlgo(100));
_gemm_algos.push_back(test_inter->TestAlgo(100));
_gemm_algos.push_back(test_output->TestAlgo(100));
_gemm_algos.push_back(test_attn_scores->TestAlgo(100));
_gemm_algos.push_back(test_attn_context->TestAlgo(100));
} else {
// Use default algo.
_gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
_gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
_gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
_gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
_gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
}
}
const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }
private:
curandGenerator_t _gen;
cublasHandle_t _cublasHandle;
void* _workspace;
uint64_t _seed;
uint64_t _curr_offset;
std::vector<std::array<int, 3>> _gemm_algos;
};
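/* Illustrative only (added sketch, not part of the original source): typical use of the
   Context singleton from a host-side kernel launcher. */
inline void example_context_usage()
{
    cudaStream_t stream = Context::Instance().GetCurrentStream();  // current PyTorch stream
    cublasHandle_t handle = Context::Instance().GetCublasHandle();
    // Reserve 16 random numbers per thread from the shared RNG offset.
    auto seed_offset = Context::Instance().IncrementOffset(16);
    (void)stream;
    (void)handle;
    (void)seed_offset;
}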