OpenDAS / deepspeed · Commits · eadbbe09
Commit eadbbe09, authored Apr 25, 2021 by 401qingkong
push rocm deepspeed v0.3.13
Parent: ab5534fc
Changes: 155 · Showing 20 changed files with 5667 additions and 1 deletion (+5667 / -1)
csrc/transformer/hip/normalize_kernels.hip  +2104 -0
csrc/transformer/hip/softmax_kernels.hip  +592 -0
csrc/transformer/hip/transform_kernels.hip  +576 -0
csrc/utils/hip/flatten_unflatten.cpp  +25 -0
deepspeed/ops/csrc  +0 -1
deepspeed/ops/csrc/adam/compat.h  +14 -0
deepspeed/ops/csrc/adam/cpu_adam.cpp  +682 -0
deepspeed/ops/csrc/adam/custom_cuda_kernel.cu  +20 -0
deepspeed/ops/csrc/adam/fused_adam_frontend.cpp  +20 -0
deepspeed/ops/csrc/adam/hip/compat.h  +14 -0
deepspeed/ops/csrc/adam/hip/cpu_adam.cpp  +682 -0
deepspeed/ops/csrc/adam/hip/custom_hip_kernel.hip  +21 -0
deepspeed/ops/csrc/adam/hip/fused_adam_frontend.cpp  +20 -0
deepspeed/ops/csrc/adam/hip/multi_tensor_adam.hip  +163 -0
deepspeed/ops/csrc/adam/hip/multi_tensor_apply.cuh  +128 -0
deepspeed/ops/csrc/adam/multi_tensor_adam.cu  +163 -0
deepspeed/ops/csrc/adam/multi_tensor_apply.cuh  +127 -0
deepspeed/ops/csrc/includes/StopWatch.h  +98 -0
deepspeed/ops/csrc/includes/Timer.h  +47 -0
deepspeed/ops/csrc/includes/context.h  +171 -0
csrc/transformer/hip/normalize_kernels.hip (new file, mode 0 → 100644) @ eadbbe09
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
namespace cg = cooperative_groups;
/*
Fused bias add, residual (elementwise) add, and normalization layer.
For FP16, this kernel does not promote to FP32, in order to utilize the 2x throughput of
__half2 instructions and to avoid the conversion overhead (1/8 of __half2 arithmetic).
For specific launch constraints, see the launch functions.
*/
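/*
* Reference sketch of the per-row math (illustrative only; H stands for row_stride
* and is not a symbol used by the kernel):
*
*   mu      = (1/H) * sum_j residual[j]
*   sigma2  = (1/H) * sum_j (residual[j] - mu)^2 + epsilon   // epsilon is folded in
*   vals[j] = gamma[j] * (residual[j] - mu) * rsqrt(sigma2) + beta[j]
*
* When training is true, mu and sigma2 are the values written to means[row] and
* vars[row] for the backward pass.
*/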
#define NORM_REG (MAX_REGISTERS / 4)
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
float* means,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / WARP_SIZE;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if (high_index < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
if (training)
if (g.thread_rank() == 0) means[row] = mean;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
__half* means,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) {
vars[row] = __float2half(variance);
means[row] = __float2half(mean);
}
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars,
T* means);
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars,
float* means)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars,
__half* means)
{
int threads = 128;
dim3 grid_dim(batch_size);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupport hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2);
}
__global__ void fused_bias_residual_layer_norm(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
bool preLayerNorm,
bool training,
float* vars,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id / 32;
float vals_arr[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
residual += (row * row_stride);
vals += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_arr[i] = residual[i * iteration_stride + id];
sum += vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = residual[high_index];
sum += vals_arr[iterations];
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / row_stride;
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_arr[i] -= mean;
variance += vals_arr[i] * vals_arr[i];
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= row_stride;
variance += epsilon;
if (training)
if (g.thread_rank() == 0) vars[row] = variance;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] = vals_arr[i] * rsqrtf(variance);
vals_arr[i] =
vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id];
vals[i * iteration_stride + id] = vals_arr[i];
}
if ((high_index) < row_stride) {
vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance);
vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index];
vals[high_index] = vals_arr[iterations];
}
}
__global__ void fused_bias_residual_layer_norm(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
bool preLayerNorm,
bool training,
__half* vars,
int row_stride)
{
#if __CUDA_ARCH__ >= 700
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int gid = id >> 5;
float2 vals_f[NORM_REG];
__shared__ float shr[MAX_WARP_NUM];
__half2* vals_cast = reinterpret_cast<__half2*>(vals);
const __half2* residual_cast = reinterpret_cast<const __half2*>(residual);
residual_cast += (row * row_stride);
vals_cast += (row * row_stride);
float sum = 0.f;
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]);
sum += vals_f[i].x;
sum += vals_f[i].y;
}
if ((high_index) < row_stride) {
vals_f[iterations] = __half22float2(residual_cast[high_index]);
sum += vals_f[iterations].x;
sum += vals_f[iterations].y;
iterations++;
}
for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) shr[gid] = sum;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); }
sum = g.shfl(sum, 0);
float mean = sum / (row_stride * 2);
float variance = 0.f;
for (int i = 0; i < iterations; i++) {
vals_f[i].x -= mean;
vals_f[i].y -= mean;
variance += vals_f[i].x * vals_f[i].x;
variance += vals_f[i].y * vals_f[i].y;
}
for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); }
if (g.thread_rank() == 0) shr[gid] = variance;
b.sync();
if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); }
variance = g.shfl(variance, 0);
variance /= (row_stride * 2);
variance += epsilon;
__half2 variance_h = __float2half2_rn(variance);
const __half2* gamma_cast = reinterpret_cast<const __half2*>(gamma);
const __half2* beta_cast = reinterpret_cast<const __half2*>(beta);
if (training && g.thread_rank() == 0) vars[row] = __float2half(variance);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
__half2 vals_arr = __float22half2_rn(vals_f[i]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr =
vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id];
vals_cast[i * iteration_stride + id] = vals_arr;
}
if ((high_index) < row_stride) {
__half2 vals_arr = __float22half2_rn(vals_f[iterations]);
vals_arr = vals_arr * h2rsqrt(variance_h);
vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index];
vals_cast[high_index] = vals_arr;
}
#endif
}
template <typename T>
void launch_bias_residual_layer_norm(T* vals,
const T* residual,
const T* gamma,
const T* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
T* vars);
/*
To tune this launch the following restrictions must be met:
For float:
row_stride == hidden_size
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
For half:
row_stride == hidden_size / 2
threads * iterations == row_stride
threads is in [32, 64, 128, 256, 512, 1024]
*/
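/*
* Worked example of the constraints above (a sketch; assumes THREADS == 256,
* the upstream DeepSpeed default, with the actual value coming from the headers):
*
*   float, hidden_size = 1024: row_stride = 1024, threads = 256,
*                              iterations = 1024 / 256 = 4
*   half,  hidden_size = 4096: row_stride = 4096 / 2 = 2048, threads = 128,
*                              iterations = 2048 / 128 = 16
*
* NORM_REG must be at least iterations + 1 so the tail element fits in vals_arr.
*/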
template <>
void launch_bias_residual_layer_norm<float>(float* vals,
const float* residual,
const float* gamma,
const float* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
float* vars)
{
int threads = THREADS;
dim3 grid_dim(batch_size);
// The kernels below have launch-size restrictions, so the supported hidden_dim ranges are enumerated explicitly.
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim);
}
template <>
void launch_bias_residual_layer_norm<__half>(__half* vals,
const __half* residual,
const __half* gamma,
const __half* beta,
float epsilon,
int batch_size,
int hidden_dim,
hipStream_t stream,
bool preLayerNorm,
bool training,
__half* vars)
{
int threads = 128;
dim3 grid_dim(batch_size);
// The kernels below have launch-size restrictions, so the supported hidden_dim ranges are enumerated explicitly.
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim(threads);
hipLaunchKernelGGL(( fused_bias_residual_layer_norm), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2);
}
/* Normalize Gamma & Betta gradients
* Compute the gradients using either X_hat or the
* normalized input (invertible path).
* Combines the transpose with the gradient computation.
*/
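/*
* Reference formulas for the reduction below (illustrative; r runs over the
* rows = batch x sequence, j over the hidden dimension):
*
*   betta_grad[j] = sum_r out_grad[r][j]
*   gamma_grad[j] = sum_r x_hat[r][j] * out_grad[r][j]
*
* where x_hat is recovered from the stored output as (vals_hat - betta) / gamma
* on the invertible path, and is vals_hat itself otherwise.
*/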
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Normalize Gamma & Betta gradients
* Compute the gradients using the input to
* the normalization.
* Combines the transpose with the gradient computation.
*/
template <typename T>
__global__ void LayerNormBackward1(const T* __restrict__ out_grad,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input.
* This type of backward is invertible!
* We do the backward using X_hat, (X - u) / sqrt(variance), or the output of the normalization.
*/
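/*
* Reference formula for the input gradient computed below (a sketch, assuming
* exact arithmetic): with g = out_grad * gamma and x_hat the normalized
* activation from the forward pass,
*
*   inp_grad = rsqrt(var) * ( g - mean(g) - x_hat * mean(g * x_hat) )
*
* where the means run over the hidden dimension of one row and var already has
* epsilon folded in by the forward kernel.
*/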
__global__ void LayerNormBackward2(const float* out_grad,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] *
sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad
vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var)
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
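/*
* Same input-gradient formula as in the invertible variant above, except that
* x_hat is recomputed from the saved statistics instead of being recovered
* from the output:
*
*   x_hat = (X - means[row]) * rsqrt(vars[row])
*/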
__global__ void LayerNormBackward2(const float* out_grad,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad[high_index];
vals_arr[iterations] *= gamma_reg;
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (X_vals[i * iteration_stride + id] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum);
if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum);
}
__global__ void LayerNormBackward2(const __half* out_grad,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h = reinterpret_cast<const __half2*>(out_grad);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
inp_grad_h += (row * row_stride);
out_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp;
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp;
}
}
template <>
void launch_layerNorm_backward<float>(const float* out_grad,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward<__half>(const __half* out_grad,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ vals_hat,
const T* __restrict__ gamma,
const T* __restrict__ betta,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width,
bool invertible)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
float betta_reg = (invertible ? (float)betta[idx] : 0.0f);
float gamma_reg = (float)gamma[idx];
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg
: (float)vals_hat[offset]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
template <typename T>
__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1,
const T* __restrict__ out_grad2,
const T* __restrict__ X_data,
const T* __restrict__ vars,
const T* __restrict__ means,
T* __restrict__ gamma_grad,
T* __restrict__ betta_grad,
int rows,
int width)
{
__shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1];
__shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1];
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);
int idx = blockDim.x * blockIdx.x + threadIdx.x;
int offset = threadIdx.y * width + idx;
int y_stride = width * TILE_DIM;
int pos = blockIdx.x * TILE_DIM + threadIdx.y;
// Loop across matrix height
float betta_tmp = 0;
float gamma_tmp = 0;
for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
float grad = (float)out_grad1[offset] + (float)out_grad2[offset];
float val = (float)X_data[offset];
val = (val - (float)means[r]) * rsqrtf((float)vars[r]);
betta_tmp += grad;
gamma_tmp += (val * grad);
offset += y_stride;
}
betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp;
gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp;
__syncthreads();
// Sum the shared buffer.
float s1 = betta_buffer[threadIdx.y][threadIdx.x];
float s2 = gamma_buffer[threadIdx.y][threadIdx.x];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < TILE_DIM; i <<= 1) {
s1 += g.shfl_down(s1, i);
s2 += g.shfl_down(s2, i);
}
if (threadIdx.x == 0) {
betta_grad[pos] = s1;
gamma_grad[pos] = s2;
}
}
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* gamma,
const float* betta,
const float* vars,
float* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
vals_hat += (row * row_stride);
inp_grad += (row * row_stride);
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] =
(invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) /
gamma_reg
: vals_hat[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] =
(invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg
: vals_hat[high_index]);
iterations++;
}
float var_reg = vars[row];
float sum = 0;
for (int i = 0; i < iterations; i++) {
sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg);
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); }
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* gamma,
const __half* betta,
const __half* vars,
__half* inp_grad,
bool invertible,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
// float2 result[iterations];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(vals_hat);
inp_grad_h += (row * row_stride);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
const __half2* betta_h = (invertible ? reinterpret_cast<const __half2*>(betta) : nullptr);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] =
(invertible
? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) /
gamma_reg
: vals_hat_h[i * iteration_stride + id]);
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] =
(invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg
: vals_hat_h[high_index]);
iterations++;
}
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
__half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg));
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 temp_f = __half22float2(temp);
vals_arr_f[i].x += temp_f.x;
vals_arr_f[i].y += temp_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* vals_hat,
const float* vars,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const float* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* vals_hat,
const __half* vars,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2],
bool invertible,
const __half* betta)
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2);
}
/* Backward Normalize (Input-Gradient)
* Using the means and variances from the input
* This type of backward is not invertible!
* We do the backward using the input (X)
*/
__global__ void LayerNormBackward2_fused_add(const float* out_grad1,
const float* out_grad2,
const float* X_vals,
const float* gamma,
const float* vars,
const float* means,
float* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
float vals_arr[NORM_REG];
float vals_hat_arr[NORM_REG];
out_grad1 += (row * row_stride);
out_grad2 += (row * row_stride);
X_vals += (row * row_stride);
inp_grad += (row * row_stride);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
float gamma_reg = gamma[i * iteration_stride + id];
vals_arr[i] = out_grad1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg;
vals_hat_arr[i] = X_vals[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
float gamma_reg = gamma[high_index];
vals_arr[iterations] = out_grad1[high_index];
vals_arr[iterations] *= gamma_reg;
vals_hat_arr[iterations] = X_vals[high_index];
iterations++;
}
float var_reg = vars[row];
float mean_reg = means[row];
float sum = 0;
float xu[NORM_REG];
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
sum += vals_arr[i] * xu[i];
vals_arr[i] *= rsqrtf(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
for (int i = 0; i < iterations; i++) {
vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg));
}
sum = 0;
for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; }
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= row_stride;
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++)
inp_grad[i * iteration_stride + id] =
(vals_arr[i] - sum) + out_grad2[i * iteration_stride + id];
if ((high_index) < row_stride)
inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index];
}
__global__ void LayerNormBackward2_fused_add(const __half* out_grad1,
const __half* out_grad2,
const __half* X_vals,
const __half* gamma,
const __half* vars,
const __half* means,
__half* inp_grad,
int row_stride)
{
int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id / WARP_SIZE;
int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE;
__shared__ float partialSum[MAX_WARP_NUM];
__half2 vals_arr[NORM_REG];
float2 vals_arr_f[NORM_REG];
__half2 vals_hat_arr[NORM_REG];
__half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad);
const __half2* out_grad_h1 = reinterpret_cast<const __half2*>(out_grad1);
const __half2* out_grad_h2 = reinterpret_cast<const __half2*>(out_grad2);
const __half2* vals_hat_h = reinterpret_cast<const __half2*>(X_vals);
out_grad_h1 += (row * row_stride);
out_grad_h2 += (row * row_stride);
inp_grad_h += (row * row_stride);
vals_hat_h += (row * row_stride);
const __half2* gamma_h = reinterpret_cast<const __half2*>(gamma);
int high_index = iterations * iteration_stride + id;
#pragma unroll
for (int i = 0; i < iterations; i++) {
__half2 gamma_reg = gamma_h[i * iteration_stride + id];
vals_arr[i] = out_grad_h1[i * iteration_stride + id];
vals_arr[i] *= gamma_reg; // out_grad * gamma
vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
__half2 gamma_reg = gamma_h[high_index];
vals_arr[iterations] = out_grad_h1[high_index];
vals_arr[iterations] *= gamma_reg; // out_grad * gamma
vals_hat_arr[iterations] = vals_hat_h[high_index];
iterations++;
}
__half mean_h = means[row];
__half var_h = vars[row];
__half2 var_reg = __halves2half2(var_h, var_h);
__half2 mean_reg = __halves2half2(mean_h, mean_h);
__half2 xu[NORM_REG];
float sum = 0.f;
for (int i = 0; i < iterations; i++) {
xu[i] = (vals_hat_arr[i] - mean_reg);
__half2 result_h = (xu[i] * vals_arr[i]);
float2 result_f = __half22float2(result_h);
sum += result_f.x;
sum += result_f.y;
vals_arr[i] *= h2rsqrt(var_reg);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
__half2 sum_h = __float2half2_rn(sum);
for (int i = 0; i < iterations; i++) {
__half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg));
vals_arr_f[i] = __half22float2(vals_arr[i]);
float2 xu_grad_f = __half22float2(xu_grad);
vals_arr_f[i].x += xu_grad_f.x;
vals_arr_f[i].y += xu_grad_f.y;
}
sum = 0.f;
for (int i = 0; i < iterations; i++) {
sum += (vals_arr_f[i].x);
sum += (vals_arr_f[i].y);
}
for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); }
if (g.thread_rank() == 0) partialSum[wid] = sum;
__syncthreads();
if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()];
#ifndef __STOCHASTIC_MODE__
__syncthreads();
#endif
for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i);
sum = g.shfl(sum, 0);
sum /= (2 * row_stride);
iterations = row_stride / iteration_stride;
for (int i = 0; i < iterations; i++) {
vals_arr_f[i].x -= sum;
vals_arr_f[i].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[i]);
inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id];
}
if ((high_index) < row_stride) {
vals_arr_f[iterations].x -= sum;
vals_arr_f[iterations].y -= sum;
__half2 temp = __float22half2_rn(vals_arr_f[iterations]);
inp_grad_h[high_index] = temp + out_grad_h2[high_index];
}
}
template <>
void launch_layerNorm_backward_fused_add<float>(const float* out_grad1,
const float* out_grad2,
const float* X_data,
const float* vars,
const float* means,
const float* gamma,
float* gamma_grad,
float* betta_grad,
float* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<float>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 1;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 2;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim);
}
template <>
void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1,
const __half* out_grad2,
const __half* X_data,
const __half* vars,
const __half* means,
const __half* gamma,
__half* gamma_grad,
__half* betta_grad,
__half* inp_grad,
int batch,
int hidden_dim,
hipStream_t stream[2])
{
int threads = THREADS;
dim3 grid_dim(hidden_dim / TILE_DIM);
dim3 block_dim(TILE_DIM, TILE_DIM);
hipLaunchKernelGGL(( LayerNormBackward1<__half>), dim3(grid_dim), dim3(block_dim), 0, stream[0],
out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim);
dim3 grid_dim2(batch);
if (hidden_dim > 8192 && hidden_dim <= 16384)
threads <<= 1;
else if (hidden_dim > 16384 && hidden_dim <= 32768)
threads <<= 2;
else if (hidden_dim > 32768 && hidden_dim <= 65536)
threads <<= 3;
else if (hidden_dim > 65536)
throw std::runtime_error("Unsupported hidden_dim.");
dim3 block_dim2(threads / 2);
hipLaunchKernelGGL(( LayerNormBackward2_fused_add), dim3(grid_dim2), dim3(block_dim2), 0, stream[1],
out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2);
}
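// Hypothetical host-side sketch (not part of the commit) showing how the FP32 launcher
// above could be driven. Buffer and function names are illustrative; `rows` stands for
// batch_size * sequence_length, and error checking is omitted for brevity.
inline void layer_norm_backward_fused_add_example(int rows, int hidden_dim)
{
    size_t n = static_cast<size_t>(rows) * hidden_dim;
    float *out_grad1, *out_grad2, *X_data, *vars, *means, *gamma;
    float *gamma_grad, *betta_grad, *inp_grad;
    hipMalloc(&out_grad1, n * sizeof(float));
    hipMalloc(&out_grad2, n * sizeof(float));
    hipMalloc(&X_data, n * sizeof(float));
    hipMalloc(&vars, rows * sizeof(float));
    hipMalloc(&means, rows * sizeof(float));
    hipMalloc(&gamma, hidden_dim * sizeof(float));
    hipMalloc(&gamma_grad, hidden_dim * sizeof(float));
    hipMalloc(&betta_grad, hidden_dim * sizeof(float));
    hipMalloc(&inp_grad, n * sizeof(float));

    hipStream_t streams[2];
    hipStreamCreate(&streams[0]);
    hipStreamCreate(&streams[1]);

    // stream[0] reduces the gamma/beta gradients, stream[1] produces the input gradient.
    launch_layerNorm_backward_fused_add<float>(out_grad1, out_grad2, X_data, vars, means,
                                               gamma, gamma_grad, betta_grad, inp_grad,
                                               rows, hidden_dim, streams);

    hipStreamSynchronize(streams[0]);
    hipStreamSynchronize(streams[1]);
}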
csrc/transformer/hip/softmax_kernels.hip
0 → 100644
#include "hip/hip_runtime.h"
#include <math.h>
#include "hip/custom_hip_layers.h"
#include "hip/general_kernels.h"
namespace cg = cooperative_groups;
// Fused attention + softmax
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(float* vals,
const float* attn_mask,
int heads,
int seq_length,
int iterations)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float4* val_cast = reinterpret_cast<float4*>(vals);
const float4* attn_mask_cast = reinterpret_cast<const float4*>(attn_mask);
float4 data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
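// The kernel proceeds in three phases over the row owned by this thread group:
// (1) load the float4-packed logits, add the additive attention mask, and track the
//     running maximum; (2) reduce the maximum and then the exp-sum, first within the
//     tile with shfl_xor and, if the row spans more than one warp, across warps through
//     the partialSum shared buffer; (3) normalize in place and write the result back.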
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float4 mask = attn_mask_cast[mask_offset + data_id];
data[i] = val_cast[data_offset + data_id];
data[i].x += mask.x;
data[i].y += mask.y;
data[i].z += mask.z;
data[i].w += mask.w;
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
data[i].x = __expf(data[i].x - max_val);
data[i].y = __expf(data[i].y - max_val);
data[i].z = __expf(data[i].z - max_val);
data[i].w = __expf(data[i].w - max_val);
sum += (data[i].x + data[i].y + data[i].z + data[i].w);
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
data[i].x /= sum;
data[i].y /= sum;
data[i].z /= sum;
data[i].w /= sum;
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) val_cast[data_offset + data_id] = data[i];
}
}
template <int tbSize, int blockStride, int tbSeq>
__global__ void attn_softmax(__half* vals,
const __half* attn_mask,
int heads,
int seq_length,
int iterations)
{
#if __CUDA_ARCH__ >= 700
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5;
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int batch = blockIdx.x;
int row = blockIdx.y;
int max_threads_in_sequence = ::max(seq_length, tbSeq);
int seq_lane = threadIdx.x % max_threads_in_sequence;
int data_offset = batch * (gridDim.y * block_width) + row * block_width +
(threadIdx.x / max_threads_in_sequence) * seq_length;
int mask_offset = batch * seq_length;
int wid = threadIdx.x >> 5;
int lane = threadIdx.x & 0x1f;
float2* val_cast = reinterpret_cast<float2*>(vals);
const float2* attn_mask_cast = reinterpret_cast<const float2*>(attn_mask);
val_cast += data_offset;
attn_mask_cast += mask_offset;
float2 low_data[MAX_THREAD_ITERATIONS];
float2 high_data[MAX_THREAD_ITERATIONS];
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 data = val_cast[data_id];
float2 mask = attn_mask_cast[data_id];
__half2* data_arr = reinterpret_cast<__half2*>(&data);
__half2* mask_arr = reinterpret_cast<__half2*>(&mask);
low_data[i] = __half22float2(data_arr[0]);
high_data[i] = __half22float2(data_arr[1]);
float2 low_mask = __half22float2(mask_arr[0]);
float2 high_mask = __half22float2(mask_arr[1]);
low_data[i].x += low_mask.x;
low_data[i].y += low_mask.y;
high_data[i].x += high_mask.x;
high_data[i].y += high_mask.y;
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
}
for (int i = 1; i < tbSize; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = max_val;
b.sync();
if (lane < warp_num) max_val = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) {
auto temp = g.shfl_xor(max_val, i);
max_val = (temp > max_val ? temp : max_val);
}
max_val = g.shfl(max_val, threadIdx.x / tbSize);
}
float sum = 0;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
low_data[i].x = __expf(low_data[i].x - max_val);
low_data[i].y = __expf(low_data[i].y - max_val);
high_data[i].x = __expf(high_data[i].x - max_val);
high_data[i].y = __expf(high_data[i].y - max_val);
sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y);
}
}
for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); }
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = sum;
b.sync();
if (lane < warp_num) sum = partialSum[lane];
#ifndef __STOCHASTIC_MODE__
b.sync();
#endif
int iters = warp_num;
if (seq_length < iteration_stride)
iters = warp_num / (iteration_stride / max_threads_in_sequence);
for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); }
sum = g.shfl(sum, threadIdx.x / tbSize);
}
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + seq_lane;
if (data_id < seq_length) {
float2 result_f;
__half2* result_h = reinterpret_cast<__half2*>(&result_f);
low_data[i].x /= sum;
low_data[i].y /= sum;
high_data[i].x /= sum;
high_data[i].y /= sum;
result_h[0] = __float22half2_rn(low_data[i]);
result_h[1] = __float22half2_rn(high_data[i]);
val_cast[data_id] = result_f;
}
}
#endif
}
template <typename T>
void launch_attn_softmax(T*, const T*, int, int, int, hipStream_t);
template <>
void launch_attn_softmax<float>(float* vals,
const float* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupported Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <>
void launch_attn_softmax<__half>(__half* vals,
const __half* attn_mask,
int batch_size,
int heads,
int sequence_length,
hipStream_t stream)
{
const int threads = 128;
int seq_length4 = sequence_length / 4;
int block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
int iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 8)
hipLaunchKernelGGL(( attn_softmax<2, (threads / 2), 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 16)
hipLaunchKernelGGL(( attn_softmax<4, (threads / 4), 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 32)
hipLaunchKernelGGL(( attn_softmax<8, (threads / 8), 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 64)
hipLaunchKernelGGL(( attn_softmax<16, (threads / 16), 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 128)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 32), 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length <= 256)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 64), 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, vals, attn_mask, heads, seq_length4, iterations);
else {
const int threads = 256;
block_compute_size =
(seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4))))
: 1);
dim3 grid_dim(batch_size, heads * sequence_length / block_compute_size);
int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads;
dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) /
subblock_max_workload * threads)
: threads);
iterations =
(sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads
: MAX_THREAD_ITERATIONS);
if (sequence_length <= 512)
hipLaunchKernelGGL(( attn_softmax<32, (threads / 128), 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4))
hipLaunchKernelGGL(( attn_softmax<32, 1, 128>), dim3(grid_dim), dim3(block_dim), 0, stream,
vals, attn_mask, heads, seq_length4, iterations);
else
throw std::runtime_error(
"Unsupported Seq_Length! Check the restriction of the max_threads and "
"max_thread_iterations!");
}
}
template <typename T, int tbSize, int blockStride>
__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length)
{
__shared__ float partialSum[MAX_WARP_NUM];
int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32)
int iteration_stride = blockDim.x;
int block_width = blockStride * seq_length;
int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride)
? (seq_length + iteration_stride - 1) / iteration_stride
: MAX_THREAD_ITERATIONS);
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<tbSize> g = cg::tiled_partition<tbSize>(b);
int row = blockIdx.x;
int id = threadIdx.x;
int wid = id >> 5;
int lane = id & 0x1f;
T val_reg[MAX_THREAD_ITERATIONS];
T soft_reg[MAX_THREAD_ITERATIONS];
float grad_reg = 0.0f;
#pragma unroll
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
val_reg[i] = out_grad[row * block_width + data_id];
soft_reg[i] = soft_inp[row * block_width + data_id];
grad_reg += ((float)val_reg[i] *
             (float)soft_reg[i]);  // doing this multiplication in half precision would
                                   // cost roughly 2% accuracy in the computation
}
}
for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
if (seq_length > tbSize) {
if (lane == 0) partialSum[wid] = grad_reg;
b.sync();
if (lane < warp_num) grad_reg = partialSum[lane];
int iters = warp_num;
if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length);
for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i);
grad_reg = g.shfl(grad_reg, id / tbSize);
}
for (int i = 0; i < iterations; i++) {
int data_id = i * iteration_stride + id;
if (data_id < block_width) {
float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg);
out_grad[row * block_width + data_id] = (T)temp;
}
}
}
template <typename T, int ITERATIONS>
__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/,
const T* output,
int softmax_length)
{
int batch_idx = blockIdx.x * blockDim.y + threadIdx.y;
int offset = batch_idx * softmax_length + threadIdx.x;
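// Each warp owns one softmax row of length `softmax_length`; lane j covers elements
// j, j + WARP_SIZE, j + 2*WARP_SIZE, ... (ITERATIONS of them). The backward rule for
// y = softmax(x) is dx_i = y_i * (dy_i - sum_k y_k * dy_k), so the kernel first
// accumulates sum_k y_k * dy_k with a warp-wide shfl_xor reduction and then applies the
// elementwise update to `grad` in place.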
grad += offset;
output += offset;
T grad_reg[ITERATIONS];
T output_reg[ITERATIONS];
float sum = 0.0;
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length) {
grad_reg[i] = grad[i * WARP_SIZE];
output_reg[i] = output[i * WARP_SIZE];
sum += (float)grad_reg[i] * (float)output_reg[i];
}
}
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
int curr_idx = threadIdx.x + i * WARP_SIZE;
if (curr_idx < softmax_length)
grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum);
}
}
template <typename T>
void launch_attn_softmax_backward_v2(T* out_grad,
const T* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream)
{
const int warps_per_block = 4;
dim3 grid_dim(batch_size * heads * seq_length / warps_per_block);
dim3 block_dim(WARP_SIZE, warps_per_block);
if (seq_length <= 32)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 1>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 64)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 2>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 128)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 4>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 256)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 8>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 384)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 12>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 512)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 16>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 768)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 24>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 1024)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 32>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else if (seq_length <= 2048)
hipLaunchKernelGGL(( softmax_backward_kernel_v2<T, 64>)
, dim3(grid_dim), dim3(block_dim), 0, stream, out_grad, soft_inp, seq_length);
else
throw std::runtime_error(
std::string("Special sequence length found in softmax backward, seq_length: ") +
std::to_string(seq_length));
}
template void launch_attn_softmax_backward_v2<__half>(__half* out_grad,
const __half* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
template void launch_attn_softmax_backward_v2<float>(float* out_grad,
const float* soft_inp,
int batch_size,
int heads,
int seq_length,
hipStream_t stream);
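// Hypothetical CPU reference (not part of the commit) for the masked softmax that
// launch_attn_softmax computes on the GPU: the additive mask, one row per batch, is
// broadcast over all heads and query positions, then a numerically stable softmax is
// taken over the key dimension. Function and parameter names are illustrative, and the
// small 1e-6 the kernel adds to the denominator is ignored.
#include <cfloat>
#include <math.h>
inline void attn_softmax_reference(float* vals,
                                   const float* attn_mask,
                                   int batch,
                                   int heads,
                                   int seq)
{
    for (int b = 0; b < batch; ++b)
        for (int h = 0; h < heads; ++h)
            for (int q = 0; q < seq; ++q) {
                float* row = vals + ((static_cast<size_t>(b) * heads + h) * seq + q) * seq;
                const float* mask = attn_mask + static_cast<size_t>(b) * seq;
                float max_val = -FLT_MAX;
                for (int k = 0; k < seq; ++k) {
                    row[k] += mask[k];
                    if (row[k] > max_val) max_val = row[k];
                }
                float sum = 0.f;
                for (int k = 0; k < seq; ++k) {
                    row[k] = expf(row[k] - max_val);
                    sum += row[k];
                }
                for (int k = 0; k < seq; ++k) row[k] /= sum;
            }
}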
csrc/transformer/hip/transform_kernels.hip
0 → 100644
#include "hip/hip_runtime.h"
#include "hip/custom_hip_layers.h"
#define rows_trans 16
#define cols_trans 16
template <typename T>
__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width)
{
__shared__ T data_block[rows_trans * (cols_trans + 1)];
int r = threadIdx.x / cols_trans;
int c = threadIdx.x % cols_trans;
int m = row_width / cols_trans;
int i = blockIdx.x / m * rows_trans + r;
int j = blockIdx.x % m * cols_trans + c;
int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS);
for (int k = 0; k < rows_trans; k += row_stride)
data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j];
__syncthreads();
i = blockIdx.x % m * rows_trans + r;
j = blockIdx.x / m * cols_trans + c;
for (int k = 0; k < rows_trans; k += row_stride)
out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k];
}
template <>
void Transpose<__half>(const __half* inp_mat,
__half* out_mat,
int rows,
int cols,
hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<__half>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <>
void Transpose<float>(const float* inp_mat, float* out_mat, int rows, int cols, hipStream_t stream)
{
int threads = THREADS;
hipLaunchKernelGGL(( Transpose_Kernel<float>), dim3((rows * cols + threads - 1) / threads), dim3(threads), 0, stream,
inp_mat, out_mat, cols, rows);
}
template <typename T>
__global__ void transform_0213(T* output,
const T* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void transform_0213<float>(float* output,
const float* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs;
}
template <>
__global__ void transform_0213<__half>(__half* output,
const __half* vals,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / head_ext; // Sequence ID (0-127)
int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3];
output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0];
#endif
}
template <>
void launch_transform_0213<float>(float* output,
const float* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<float>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_transform_0213<__half>(__half* output,
const __half* vals,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream)
{
hidden_dim >>= 3;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, (seq_length * head_ext));
hipLaunchKernelGGL(( transform_0213<__half>)
, dim3(grid_dim), dim3(block_dim), 0, stream, output, vals, hidden_dim, seq_length, heads, head_ext);
}
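// Hypothetical element-wise CPU sketch (not part of the commit) of the 0213 permutation
// that transform_0213 implements with vectorized loads: input laid out as
// [batch, seq, heads, head_dim], output as [batch, heads, seq, head_dim]. Here head_dim
// counts scalar elements; the kernels above move them in packs of four or eight.
inline void transform_0213_reference(float* out,
                                     const float* in,
                                     int batch,
                                     int seq,
                                     int heads,
                                     int head_dim)
{
    for (int b = 0; b < batch; ++b)
        for (int s = 0; s < seq; ++s)
            for (int h = 0; h < heads; ++h)
                for (int d = 0; d < head_dim; ++d) {
                    size_t src = ((static_cast<size_t>(b) * seq + s) * heads + h) * head_dim + d;
                    size_t dst = ((static_cast<size_t>(b) * heads + h) * seq + s) * head_dim + d;
                    out[dst] = in[src];
                }
}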
// Bias add
template <typename T>
__global__ void bias_add_transform_0213(T* output,
const T* vals,
const T* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext);
template <>
__global__ void bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride +
d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3];
float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3];
float4 outputs;
outputs.x = inputs.x + biases.x;
outputs.y = inputs.y + biases.y;
outputs.z = inputs.z + biases.z;
outputs.w = inputs.w + biases.w;
output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride + d3] = outputs;
}
#define ATTN_H 3
#define MAX_SEQ_LINE 10
template <>
__global__ void bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = blockIdx.z / head_ext; // Hidden count
int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr;
float4 bias_arr;
float4 output_arr;
__half2* vals_half = reinterpret_cast<__half2*>(&vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(&bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(&output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
vals_vec += (d0 * d0_stride * (gridDim.z / head_ext));
vals_vec += (d1 * d1_stride * (gridDim.z / head_ext));
vals_vec += (cnt * d1_stride);
vals_vec += (d2 * d2_stride);
bias_vec += (cnt * d1_stride);
bias_vec += (d2 * d2_stride);
output_vec += (cnt * d0_stride * gridDim.x);
output_vec += (d1 * d2_stride);
output_vec += (d0 * d0_stride);
output_vec += (d2 * d2_out_stride);
bias_arr = bias_vec[d3];
vals_arr = vals_vec[d3];
#if defined(__ACC_HALF__)
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
#else
float2 bias_arr_f[4];
float2 vals_arr_f[4];
#pragma unroll
for (int l = 0; l < 4; l++) {
bias_arr_f[l] = __half22float2(bias_half[l]);
vals_arr_f[l] = __half22float2(vals_half[l]);
vals_arr_f[l].x += bias_arr_f[l].x;
vals_arr_f[l].y += bias_arr_f[l].y;
output_half[l] = __float22half2_rn(vals_arr_f[l]);
}
#endif
output_vec[d3] = output_arr;
#endif
}
__global__ void bias_add_transform_0213_v2(__half* output,
const __half* vals,
const __half* bias,
int hidden_dim,
int seq_length,
int heads)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8
int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = d2_stride * seq_length;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y; // Sequence ID (0-127)
int cnt = threadIdx.z; // blockIdx.z; // Hidden count
int d2 = threadIdx.y; // Head (0-11)
int d3 = threadIdx.x; // Values (groups of 4)
float4 vals_arr[1];
float4 bias_arr[1];
float4 output_arr[1];
__half2* vals_half = reinterpret_cast<__half2*>(vals_arr);
__half2* bias_half = reinterpret_cast<__half2*>(bias_arr);
__half2* output_half = reinterpret_cast<__half2*>(output_arr);
const float4* vals_vec = reinterpret_cast<const float4*>(vals);
const float4* bias_vec = reinterpret_cast<const float4*>(bias);
float4* output_vec = reinterpret_cast<float4*>(output);
int iter_index = cnt * d1_stride + d2 * d2_stride + d3;
int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1);
bias_arr[0] = bias_vec[iter_index];
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
vals_arr[0] = vals_vec[input_offset + iter_id];
output_half[0] = vals_half[0] + bias_half[0];
output_half[1] = vals_half[1] + bias_half[1];
output_half[2] = vals_half[2] + bias_half[2];
output_half[3] = vals_half[3] + bias_half[3];
in_data[iter_id] = output_arr[0];
}
__syncthreads();
iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_out_stride * gridDim.x);
int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1);
int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride;
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = (iter * iteration_stride) + head_count;
int iter_offset =
(iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride;
output_vec[out_index + iter_offset] =
in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)];
}
#endif
}
// [B S C*H] -> C * [B A S N]
template <>
void launch_bias_add_transform_0213<float>(float* output,
const float* vals,
const float* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<float>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
}
template <>
void launch_bias_add_transform_0213<__half>(__half* output,
const __half* vals,
const __half* bias,
int batch_size,
int seq_length,
int hidden_dim,
int heads,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 block_dim(hidden_dim / heads, (heads / head_ext));
dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext));
hipLaunchKernelGGL(( bias_add_transform_0213<__half>), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads, head_ext);
} else {
dim3 block_dim(hidden_dim / heads, heads, trans_count);
dim3 grid_dim(batch_size, seq_length / 2);
hipLaunchKernelGGL(( bias_add_transform_0213_v2), dim3(grid_dim), dim3(block_dim), 0, stream,
output, vals, bias, hidden_dim, seq_length, heads);
}
}
template <typename T>
__global__ void transform4d_0213(T* out,
const T* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext);
template <>
__global__ void transform4d_0213<float>(float* out,
const float* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = d0_stride / heads;
int d2_stride = hidden_dim / heads;
int d0_out_stride = d0_stride;
int d1_out_stride = d2_stride;
int d2_out_stride = hidden_dim;
int d0 = blockIdx.x; // Batch
int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head
int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length;
int cnt = blockIdx.z;
int d3 = threadIdx.x; // Values (groups of 8)
if (d2 < seq_length) {
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride +
d2 * d2_stride + d3];
out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride +
d2 * d2_out_stride * gridDim.z + d3] = vals_vec;
}
}
template <>
__global__ void transform4d_0213<__half>(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim,
int head_ext)
{
#if __CUDA_ARCH__ >= 700
int d0_stride = hidden_dim * (seq_length / head_ext);
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head
int d2 = blockIdx.z / head_ext; // Sequence
int cnt = blockIdx.y; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
in_vec += (cnt * d0_stride * gridDim.x);
in_vec += (d0 * d0_stride);
in_vec += (d2 * d2_stride);
in_vec += (d1 * d2_stride * seq_length);
out_vec += (cnt * d1_stride);
out_vec += (d1 * d2_stride);
out_vec += (d0 * d0_stride * gridDim.y);
out_vec += (d2 * d1_stride * gridDim.y);
out_vec[d3] = in_vec[d3];
#endif
}
__global__ void transform4d_0213_v2(__half* out,
const __half* in,
int heads,
int seq_length,
int hidden_dim)
{
#if __CUDA_ARCH__ >= 700
__shared__ float4 in_data[3072];
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
int d2_stride = hidden_dim / heads;
int d0 = blockIdx.x; // Batch
int d1 = threadIdx.y; // Head
int d2 = blockIdx.y; // Sequence
int cnt = threadIdx.z; // Hidden count
int d3 = threadIdx.x; // Values (groups of 8)
const float4* in_vec = reinterpret_cast<const float4*>(in);
float4* out_vec = reinterpret_cast<float4*>(out);
int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride;
int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1);
int iteration_stride = blockDim.z * (blockDim.y >> 1);
int matrix_stride = (d0_stride * gridDim.x);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_row = iter * iteration_stride + head_count;
int iter_offset = (iter_row % blockDim.y) * d2_stride;
in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] =
in_vec[input_offset + iter_offset * seq_length +
(iter_row / blockDim.y) * matrix_stride];
}
__syncthreads();
iteration_stride = d1_stride * blockDim.z;
int iter_index = cnt * d1_stride + d1 * d2_stride + d3;
int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1);
#pragma unroll
for (int iter = 0; iter < 2; iter++) {
int iter_id = iter * iteration_stride + iter_index;
out_vec[output_offset + iter_id] = in_data[iter_id];
}
#endif
}
// 3 * [B A S N] -> [B S C*H]
template <>
void launch_transform4d_0213<float>(float* out,
const float* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 2;
dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count);
dim3 block_dims(hidden_dim / heads, 8);
hipLaunchKernelGGL(( transform4d_0213<float>)
, dim3(grid_dims), dim3(block_dims), 0, stream, out, in, heads, seq_length, hidden_dim, 1);
}
template <>
void launch_transform4d_0213<__half>(__half* out,
const __half* in,
int batch_size,
int heads,
int seq_length,
int hidden_dim,
hipStream_t stream,
int trans_count)
{
hidden_dim >>= 3;
if (hidden_dim > 128 || hidden_dim < 16) {
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext));
dim3 block_dims(hidden_dim / heads, (heads / head_ext));
hipLaunchKernelGGL(( transform4d_0213<__half>), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim, head_ext);
} else {
dim3 grid_dims(batch_size, seq_length / 2);
dim3 block_dims(hidden_dim / heads, heads, trans_count);
hipLaunchKernelGGL(( transform4d_0213_v2), dim3(grid_dims), dim3(block_dims), 0, stream,
out, in, heads, seq_length, hidden_dim);
}
}
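// Hypothetical element-wise CPU sketch (not part of the commit) of the inverse transform
// performed by transform4d_0213: trans_count stacked [batch, heads, seq, head_dim]
// matrices are merged back into [batch, seq, trans_count * heads * head_dim]. As above,
// head_dim counts scalar elements; the kernels move them in vectorized packs.
inline void transform4d_0213_reference(float* out,
                                       const float* in,
                                       int batch,
                                       int heads,
                                       int seq,
                                       int head_dim,
                                       int trans_count)
{
    for (int c = 0; c < trans_count; ++c)
        for (int b = 0; b < batch; ++b)
            for (int h = 0; h < heads; ++h)
                for (int s = 0; s < seq; ++s)
                    for (int d = 0; d < head_dim; ++d) {
                        size_t src =
                            (((static_cast<size_t>(c) * batch + b) * heads + h) * seq + s) *
                                head_dim +
                            d;
                        size_t dst =
                            (((static_cast<size_t>(b) * seq + s) * trans_count + c) * heads + h) *
                                head_dim +
                            d;
                        out[dst] = in[src];
                    }
}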
csrc/utils/hip/flatten_unflatten.cpp
0 → 100644
/*
Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <torch/csrc/utils/tensor_flatten.h>
#include <torch/extension.h>
// https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_flatten.h
at::Tensor flatten(std::vector<at::Tensor> tensors)
{
    return torch::utils::flatten_dense_tensors(tensors);
}

std::vector<at::Tensor> unflatten(at::Tensor flat, std::vector<at::Tensor> tensors)
{
    return torch::utils::unflatten_dense_tensors(flat, tensors);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("flatten", &flatten, "Flatten dense tensors");
    m.def("unflatten", &unflatten, "Unflatten dense tensors");
}
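// Hypothetical C++ usage sketch (not part of the commit) of the two bindings above. The
// extension is normally driven from Python, but the call pattern is the same: pack a list
// of parameter tensors into one contiguous 1-D buffer, then view it back with the
// original shapes. It relies only on the headers already included in this file.
inline void flatten_unflatten_example()
{
    std::vector<at::Tensor> params = {at::ones({4, 4}), at::zeros({8})};
    at::Tensor flat = flatten(params);                        // 24-element 1-D tensor
    std::vector<at::Tensor> views = unflatten(flat, params);  // shapes {4,4} and {8}
}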
deepspeed/ops/csrc
deleted
120000 → 0
../../csrc
\ No newline at end of file
deepspeed/ops/csrc/adam/compat.h
0 → 100644
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
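// Usage note (not part of the original header): DATA_PTR papers over the
// data<T>() -> data_ptr<T>() rename in the ATen tensor API, so extension code can write,
// for example, `float* p = some_tensor.DATA_PTR<float>();` and build against either
// PyTorch version.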
deepspeed/ops/csrc/adam/cpu_adam.cpp
0 → 100644
#include "cpu_adam.h"
#include <cuda_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "custom_cuda_layers.h"
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
// C++ interface
void Adam_Optimizer::Step(float* _params,
                          float* grads,
                          float* _exp_avg,
                          float* _exp_avg_sq,
                          size_t _param_size,
                          __half* dev_params)
{
    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;

    float step_size = -1 * _alpha / _bias_correction1;
    float w_decay = -1 * _alpha * _weight_decay;
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH) {
            AVX_Data grad_4;
            grad_4.data = SIMD_LOAD(grads + i);

            AVX_Data momentum_4;
            momentum_4.data = SIMD_LOAD(_exp_avg + i);
            AVX_Data variance_4;
            variance_4.data = SIMD_LOAD(_exp_avg_sq + i);

            AVX_Data param_4;
            param_4.data = SIMD_LOAD(_params + i);

            if (_weight_decay > 0 && !_adamw_mode) {
                grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
            }
            momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
            momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);

            variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
            grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
            variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);

            grad_4.data = SIMD_SQRT(variance_4.data);
            grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
            grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);
            if (_weight_decay > 0 && _adamw_mode) {
                param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data);
            }
            param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

            SIMD_STORE(_params + i, param_4.data);

            if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);

            SIMD_STORE(_exp_avg + i, momentum_4.data);
            SIMD_STORE(_exp_avg_sq + i, variance_4.data);
        }
        if (dev_params) {
            launch_param_update(
                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            _buf_index = !_buf_index;
        }
    }

#endif

    if (_param_size > rounded_size) {
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
            if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = grads[k];
                float param = _params[k];
                float momentum = _exp_avg[k];
                float variance = _exp_avg_sq[k];
                if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }
                momentum = momentum * _betta1;
                momentum = grad * betta1_minus1 + momentum;

                variance = variance * _betta2;
                grad = grad * grad;
                variance = grad * betta2_minus1 + variance;

                grad = sqrt(variance);
                grad = grad * _bias_correction2 + _eps;
                grad = momentum / grad;
                if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
                param = grad * step_size + param;

                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;

                _params[k] = param;
                _exp_avg[k] = momentum;
                _exp_avg_sq[k] = variance;
            }
            if (dev_params) {
                launch_param_update(
                    _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]);
                _buf_index = !_buf_index;
            }
        }
    }
}
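// Note on the tiling scheme shared by Step, Step_4, and Step_8: parameters are processed
// in TILE-sized chunks. When a half-precision GPU copy is requested (dev_params != nullptr),
// the freshly updated FP32 values of a chunk are staged into one of the two host-side
// buffers in _doubled_buffer and pushed to the device by launch_param_update on the
// matching stream while the CPU moves on to the next chunk; _buf_index ping-pongs between
// the buffers, and cudaStreamSynchronize is issued before a buffer is reused (every
// second tile).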
void Adam_Optimizer::Step_4(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            __half* dev_params)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
            AVX_Data grad_4[4];
            grad_4[0].data = SIMD_LOAD(grads + i);
            grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
            grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
            grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);

            AVX_Data momentum_4[4];
            momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
            momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
            momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
            momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);

            AVX_Data variance_4[4];
            variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
            variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
            variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
            variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);

            AVX_Data param_4[4];
            param_4[0].data = SIMD_LOAD(_params + i);
            param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
            param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
            param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);

            if (_weight_decay > 0 && !_adamw_mode) {
                grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
                grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
                grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
                grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
            }

            momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
            momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
            momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
            momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
            momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
            momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
            momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
            momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);

            variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
            variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
            variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
            variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);

            grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
            grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
            grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
            grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);

            variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
            variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
            variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
            variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);

            grad_4[0].data = SIMD_SQRT(variance_4[0].data);
            grad_4[1].data = SIMD_SQRT(variance_4[1].data);
            grad_4[2].data = SIMD_SQRT(variance_4[2].data);
            grad_4[3].data = SIMD_SQRT(variance_4[3].data);

            grad_4[0].data = SIMD_FMA(grad_4[0].data, bias2_sqrt.data, eps_4.data);
            grad_4[1].data = SIMD_FMA(grad_4[1].data, bias2_sqrt.data, eps_4.data);
            grad_4[2].data = SIMD_FMA(grad_4[2].data, bias2_sqrt.data, eps_4.data);
            grad_4[3].data = SIMD_FMA(grad_4[3].data, bias2_sqrt.data, eps_4.data);

            grad_4[0].data = SIMD_DIV(momentum_4[0].data, grad_4[0].data);
            grad_4[1].data = SIMD_DIV(momentum_4[1].data, grad_4[1].data);
            grad_4[2].data = SIMD_DIV(momentum_4[2].data, grad_4[2].data);
            grad_4[3].data = SIMD_DIV(momentum_4[3].data, grad_4[3].data);

            if (_weight_decay > 0 && _adamw_mode) {
                param_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, param_4[0].data);
                param_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, param_4[1].data);
                param_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, param_4[2].data);
                param_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, param_4[3].data);
            }

            param_4[0].data = SIMD_FMA(grad_4[0].data, step_size_4.data, param_4[0].data);
            param_4[1].data = SIMD_FMA(grad_4[1].data, step_size_4.data, param_4[1].data);
            param_4[2].data = SIMD_FMA(grad_4[2].data, step_size_4.data, param_4[2].data);
            param_4[3].data = SIMD_FMA(grad_4[3].data, step_size_4.data, param_4[3].data);

            SIMD_STORE(_params + i, param_4[0].data);
            SIMD_STORE(_params + i + SIMD_WIDTH, param_4[1].data);
            SIMD_STORE(_params + i + (SIMD_WIDTH << 1), param_4[2].data);
            SIMD_STORE(_params + i + SIMD_WIDTH * 3, param_4[3].data);

            if (dev_params) {
                SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4[0].data);
                SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH, param_4[1].data);
                SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + (SIMD_WIDTH << 1),
                           param_4[2].data);
                SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * 3, param_4[3].data);
            }

            SIMD_STORE(_exp_avg + i, momentum_4[0].data);
            SIMD_STORE(_exp_avg + i + SIMD_WIDTH, momentum_4[1].data);
            SIMD_STORE(_exp_avg + i + (SIMD_WIDTH << 1), momentum_4[2].data);
            SIMD_STORE(_exp_avg + i + SIMD_WIDTH * 3, momentum_4[3].data);

            SIMD_STORE(_exp_avg_sq + i, variance_4[0].data);
            SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH, variance_4[1].data);
            SIMD_STORE(_exp_avg_sq + i + (SIMD_WIDTH << 1), variance_4[2].data);
            SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * 3, variance_4[3].data);
        }
        if (dev_params) {
            launch_param_update(
                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            _buf_index = !_buf_index;
        }
    }
#endif
    if (_param_size > rounded_size)
        Step((_params + rounded_size),
             (grads + rounded_size),
             (_exp_avg + rounded_size),
             (_exp_avg_sq + rounded_size),
             (_param_size - rounded_size),
             (dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}
int create_adam_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float eps = 1e-8,
                          float weight_decay = 0,
                          bool adamw_mode = true)
{
    auto opt =
        std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);

    s_optimizers[optimizer_id] = opt;

#if defined(__AVX512__)
    std::cout << "Adam Optimizer #" << optimizer_id
              << " is created with AVX512 arithmetic capability." << std::endl;
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha,
           betta1,
           betta2,
           weight_decay,
           (int)adamw_mode);
#else
#if defined(__AVX256__)
    std::cout << "Adam Optimizer #" << optimizer_id
              << " is created with AVX2 arithmetic capability." << std::endl;
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha,
           betta1,
           betta2,
           weight_decay,
           (int)adamw_mode);
#else
    std::cout << "Adam Optimizer #" << optimizer_id
              << " is created with scalar arithmetic capability." << std::endl;
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha,
           betta1,
           betta2,
           weight_decay,
           (int)adamw_mode);
#endif
#endif
    return 0;
}
void Adam_Optimizer::Step_8(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            __half* dev_params)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
        if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
            AVX_Data grad_4[8];
            grad_4[0].data = SIMD_LOAD(grads + i);
            grad_4[1].data = SIMD_LOAD(grads + i + SIMD_WIDTH);
            grad_4[2].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 1));
            grad_4[3].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 3);
            grad_4[4].data = SIMD_LOAD(grads + i + (SIMD_WIDTH << 2));
            grad_4[5].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 5);
            grad_4[6].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 6);
            grad_4[7].data = SIMD_LOAD(grads + i + SIMD_WIDTH * 7);

            AVX_Data momentum_4[8];
            momentum_4[0].data = SIMD_LOAD(_exp_avg + i);
            momentum_4[1].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH);
            momentum_4[2].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 1));
            momentum_4[3].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 3);
            momentum_4[4].data = SIMD_LOAD(_exp_avg + i + (SIMD_WIDTH << 2));
            momentum_4[5].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 5);
            momentum_4[6].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 6);
            momentum_4[7].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * 7);

            AVX_Data variance_4[8];
            variance_4[0].data = SIMD_LOAD(_exp_avg_sq + i);
            variance_4[1].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH);
            variance_4[2].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 1));
            variance_4[3].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 3);
            variance_4[4].data = SIMD_LOAD(_exp_avg_sq + i + (SIMD_WIDTH << 2));
            variance_4[5].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 5);
            variance_4[6].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 6);
            variance_4[7].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * 7);

            AVX_Data param_4[8];
            param_4[0].data = SIMD_LOAD(_params + i);
            param_4[1].data = SIMD_LOAD(_params + i + SIMD_WIDTH);
            param_4[2].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 1));
            param_4[3].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 3);
            param_4[4].data = SIMD_LOAD(_params + i + (SIMD_WIDTH << 2));
            param_4[5].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 5);
            param_4[6].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 6);
            param_4[7].data = SIMD_LOAD(_params + i + SIMD_WIDTH * 7);

            if (_weight_decay > 0 && !_adamw_mode) {
                grad_4[0].data = SIMD_FMA(param_4[0].data, weight_decay4.data, grad_4[0].data);
                grad_4[1].data = SIMD_FMA(param_4[1].data, weight_decay4.data, grad_4[1].data);
                grad_4[2].data = SIMD_FMA(param_4[2].data, weight_decay4.data, grad_4[2].data);
                grad_4[3].data = SIMD_FMA(param_4[3].data, weight_decay4.data, grad_4[3].data);
                grad_4[4].data = SIMD_FMA(param_4[4].data, weight_decay4.data, grad_4[4].data);
                grad_4[5].data = SIMD_FMA(param_4[5].data, weight_decay4.data, grad_4[5].data);
                grad_4[6].data = SIMD_FMA(param_4[6].data, weight_decay4.data, grad_4[6].data);
                grad_4[7].data = SIMD_FMA(param_4[7].data, weight_decay4.data, grad_4[7].data);
            }

            momentum_4[0].data = SIMD_MUL(momentum_4[0].data, betta1_4.data);
            momentum_4[0].data = SIMD_FMA(grad_4[0].data, betta1_minus1_4.data, momentum_4[0].data);
            momentum_4[1].data = SIMD_MUL(momentum_4[1].data, betta1_4.data);
            momentum_4[1].data = SIMD_FMA(grad_4[1].data, betta1_minus1_4.data, momentum_4[1].data);
            momentum_4[2].data = SIMD_MUL(momentum_4[2].data, betta1_4.data);
            momentum_4[2].data = SIMD_FMA(grad_4[2].data, betta1_minus1_4.data, momentum_4[2].data);
            momentum_4[3].data = SIMD_MUL(momentum_4[3].data, betta1_4.data);
            momentum_4[3].data = SIMD_FMA(grad_4[3].data, betta1_minus1_4.data, momentum_4[3].data);
            momentum_4[4].data = SIMD_MUL(momentum_4[4].data, betta1_4.data);
            momentum_4[4].data = SIMD_FMA(grad_4[4].data, betta1_minus1_4.data, momentum_4[4].data);
            momentum_4[5].data = SIMD_MUL(momentum_4[5].data, betta1_4.data);
            momentum_4[5].data = SIMD_FMA(grad_4[5].data, betta1_minus1_4.data, momentum_4[5].data);
            momentum_4[6].data = SIMD_MUL(momentum_4[6].data, betta1_4.data);
            momentum_4[6].data = SIMD_FMA(grad_4[6].data, betta1_minus1_4.data, momentum_4[6].data);
            momentum_4[7].data = SIMD_MUL(momentum_4[7].data, betta1_4.data);
            momentum_4[7].data = SIMD_FMA(grad_4[7].data, betta1_minus1_4.data, momentum_4[7].data);

            variance_4[0].data = SIMD_MUL(variance_4[0].data, betta2_4.data);
            variance_4[1].data = SIMD_MUL(variance_4[1].data, betta2_4.data);
            variance_4[2].data = SIMD_MUL(variance_4[2].data, betta2_4.data);
            variance_4[3].data = SIMD_MUL(variance_4[3].data, betta2_4.data);
            variance_4[4].data = SIMD_MUL(variance_4[4].data, betta2_4.data);
            variance_4[5].data = SIMD_MUL(variance_4[5].data, betta2_4.data);
            variance_4[6].data = SIMD_MUL(variance_4[6].data, betta2_4.data);
            variance_4[7].data = SIMD_MUL(variance_4[7].data, betta2_4.data);

            grad_4[0].data = SIMD_MUL(grad_4[0].data, grad_4[0].data);
            grad_4[1].data = SIMD_MUL(grad_4[1].data, grad_4[1].data);
            grad_4[2].data = SIMD_MUL(grad_4[2].data, grad_4[2].data);
            grad_4[3].data = SIMD_MUL(grad_4[3].data, grad_4[3].data);
            grad_4[4].data = SIMD_MUL(grad_4[4].data, grad_4[4].data);
            grad_4[5].data = SIMD_MUL(grad_4[5].data, grad_4[5].data);
            grad_4[6].data = SIMD_MUL(grad_4[6].data, grad_4[6].data);
            grad_4[7].data = SIMD_MUL(grad_4[7].data, grad_4[7].data);

            variance_4[0].data = SIMD_FMA(grad_4[0].data, betta2_minus1_4.data, variance_4[0].data);
            variance_4[1].data = SIMD_FMA(grad_4[1].data, betta2_minus1_4.data, variance_4[1].data);
            variance_4[2].data = SIMD_FMA(grad_4[2].data, betta2_minus1_4.data, variance_4[2].data);
            variance_4[3].data = SIMD_FMA(grad_4[3].data, betta2_minus1_4.data, variance_4[3].data);
            variance_4[4].data = SIMD_FMA(grad_4[4].data, betta2_minus1_4.data, variance_4[4].data);
            variance_4[5].data = SIMD_FMA(grad_4[5].data, betta2_minus1_4.data, variance_4[5].data);
            variance_4[6].data = SIMD_FMA(grad_4[6].data, betta2_minus1_4.data, variance_4[6].data);
            variance_4[7].data = SIMD_FMA(grad_4[7].data, betta2_minus1_4.data, variance_4[7].data);

            grad_4[0].data = SIMD_SQRT(variance_4[0].data);
            grad_4[1].data = SIMD_SQRT(variance_4[1].data);
            grad_4[2].data = SIMD_SQRT(
variance_4
[
2
].
data
);
grad_4
[
3
].
data
=
SIMD_SQRT
(
variance_4
[
3
].
data
);
grad_4
[
4
].
data
=
SIMD_SQRT
(
variance_4
[
4
].
data
);
grad_4
[
5
].
data
=
SIMD_SQRT
(
variance_4
[
5
].
data
);
grad_4
[
6
].
data
=
SIMD_SQRT
(
variance_4
[
6
].
data
);
grad_4
[
7
].
data
=
SIMD_SQRT
(
variance_4
[
7
].
data
);
grad_4
[
0
].
data
=
SIMD_FMA
(
grad_4
[
0
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
1
].
data
=
SIMD_FMA
(
grad_4
[
1
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
2
].
data
=
SIMD_FMA
(
grad_4
[
2
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
3
].
data
=
SIMD_FMA
(
grad_4
[
3
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
4
].
data
=
SIMD_FMA
(
grad_4
[
4
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
5
].
data
=
SIMD_FMA
(
grad_4
[
5
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
6
].
data
=
SIMD_FMA
(
grad_4
[
6
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
7
].
data
=
SIMD_FMA
(
grad_4
[
7
].
data
,
bias2_sqrt
.
data
,
eps_4
.
data
);
grad_4
[
0
].
data
=
SIMD_DIV
(
momentum_4
[
0
].
data
,
grad_4
[
0
].
data
);
grad_4
[
1
].
data
=
SIMD_DIV
(
momentum_4
[
1
].
data
,
grad_4
[
1
].
data
);
grad_4
[
2
].
data
=
SIMD_DIV
(
momentum_4
[
2
].
data
,
grad_4
[
2
].
data
);
grad_4
[
3
].
data
=
SIMD_DIV
(
momentum_4
[
3
].
data
,
grad_4
[
3
].
data
);
grad_4
[
4
].
data
=
SIMD_DIV
(
momentum_4
[
4
].
data
,
grad_4
[
4
].
data
);
grad_4
[
5
].
data
=
SIMD_DIV
(
momentum_4
[
5
].
data
,
grad_4
[
5
].
data
);
grad_4
[
6
].
data
=
SIMD_DIV
(
momentum_4
[
6
].
data
,
grad_4
[
6
].
data
);
grad_4
[
7
].
data
=
SIMD_DIV
(
momentum_4
[
7
].
data
,
grad_4
[
7
].
data
);
if
(
_weight_decay
>
0
&&
_adamw_mode
)
{
param_4
[
0
].
data
=
SIMD_FMA
(
param_4
[
0
].
data
,
weight_decay4
.
data
,
param_4
[
0
].
data
);
param_4
[
1
].
data
=
SIMD_FMA
(
param_4
[
1
].
data
,
weight_decay4
.
data
,
param_4
[
1
].
data
);
param_4
[
2
].
data
=
SIMD_FMA
(
param_4
[
2
].
data
,
weight_decay4
.
data
,
param_4
[
2
].
data
);
param_4
[
3
].
data
=
SIMD_FMA
(
param_4
[
3
].
data
,
weight_decay4
.
data
,
param_4
[
3
].
data
);
param_4
[
4
].
data
=
SIMD_FMA
(
param_4
[
4
].
data
,
weight_decay4
.
data
,
param_4
[
4
].
data
);
param_4
[
5
].
data
=
SIMD_FMA
(
param_4
[
5
].
data
,
weight_decay4
.
data
,
param_4
[
5
].
data
);
param_4
[
6
].
data
=
SIMD_FMA
(
param_4
[
6
].
data
,
weight_decay4
.
data
,
param_4
[
6
].
data
);
param_4
[
7
].
data
=
SIMD_FMA
(
param_4
[
7
].
data
,
weight_decay4
.
data
,
param_4
[
7
].
data
);
}
param_4
[
0
].
data
=
SIMD_FMA
(
grad_4
[
0
].
data
,
step_size_4
.
data
,
param_4
[
0
].
data
);
param_4
[
1
].
data
=
SIMD_FMA
(
grad_4
[
1
].
data
,
step_size_4
.
data
,
param_4
[
1
].
data
);
param_4
[
2
].
data
=
SIMD_FMA
(
grad_4
[
2
].
data
,
step_size_4
.
data
,
param_4
[
2
].
data
);
param_4
[
3
].
data
=
SIMD_FMA
(
grad_4
[
3
].
data
,
step_size_4
.
data
,
param_4
[
3
].
data
);
param_4
[
4
].
data
=
SIMD_FMA
(
grad_4
[
4
].
data
,
step_size_4
.
data
,
param_4
[
4
].
data
);
param_4
[
5
].
data
=
SIMD_FMA
(
grad_4
[
5
].
data
,
step_size_4
.
data
,
param_4
[
5
].
data
);
param_4
[
6
].
data
=
SIMD_FMA
(
grad_4
[
6
].
data
,
step_size_4
.
data
,
param_4
[
6
].
data
);
param_4
[
7
].
data
=
SIMD_FMA
(
grad_4
[
7
].
data
,
step_size_4
.
data
,
param_4
[
7
].
data
);
SIMD_STORE
(
_params
+
i
,
param_4
[
0
].
data
);
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
,
param_4
[
1
].
data
);
SIMD_STORE
(
_params
+
i
+
(
SIMD_WIDTH
<<
1
),
param_4
[
2
].
data
);
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
3
,
param_4
[
3
].
data
);
SIMD_STORE
(
_params
+
i
+
(
SIMD_WIDTH
<<
2
),
param_4
[
4
].
data
);
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
5
,
param_4
[
5
].
data
);
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
6
,
param_4
[
6
].
data
);
SIMD_STORE
(
_params
+
i
+
SIMD_WIDTH
*
7
,
param_4
[
7
].
data
);
if
(
dev_params
)
{
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
),
param_4
[
0
].
data
);
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
)
+
SIMD_WIDTH
,
param_4
[
1
].
data
);
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
)
+
(
SIMD_WIDTH
<<
1
),
param_4
[
2
].
data
);
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
)
+
SIMD_WIDTH
*
3
,
param_4
[
3
].
data
);
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
)
+
(
SIMD_WIDTH
<<
2
),
param_4
[
4
].
data
);
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
)
+
SIMD_WIDTH
*
5
,
param_4
[
5
].
data
);
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
)
+
SIMD_WIDTH
*
6
,
param_4
[
6
].
data
);
SIMD_STORE
(
_doubled_buffer
[
_buf_index
]
+
(
i
-
t
)
+
SIMD_WIDTH
*
7
,
param_4
[
7
].
data
);
}
SIMD_STORE
(
_exp_avg
+
i
,
momentum_4
[
0
].
data
);
SIMD_STORE
(
_exp_avg
+
i
+
SIMD_WIDTH
,
momentum_4
[
1
].
data
);
SIMD_STORE
(
_exp_avg
+
i
+
(
SIMD_WIDTH
<<
1
),
momentum_4
[
2
].
data
);
SIMD_STORE
(
_exp_avg
+
i
+
SIMD_WIDTH
*
3
,
momentum_4
[
3
].
data
);
SIMD_STORE
(
_exp_avg
+
i
+
(
SIMD_WIDTH
<<
2
),
momentum_4
[
4
].
data
);
SIMD_STORE
(
_exp_avg
+
i
+
SIMD_WIDTH
*
5
,
momentum_4
[
5
].
data
);
SIMD_STORE
(
_exp_avg
+
i
+
SIMD_WIDTH
*
6
,
momentum_4
[
6
].
data
);
SIMD_STORE
(
_exp_avg
+
i
+
SIMD_WIDTH
*
7
,
momentum_4
[
7
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
,
variance_4
[
0
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
,
variance_4
[
1
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
(
SIMD_WIDTH
<<
1
),
variance_4
[
2
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
3
,
variance_4
[
3
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
(
SIMD_WIDTH
<<
2
),
variance_4
[
4
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
5
,
variance_4
[
5
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
6
,
variance_4
[
6
].
data
);
SIMD_STORE
(
_exp_avg_sq
+
i
+
SIMD_WIDTH
*
7
,
variance_4
[
7
].
data
);
}
if
(
dev_params
)
{
launch_param_update
(
_doubled_buffer
[
_buf_index
],
dev_params
+
t
,
copy_size
,
_streams
[
_buf_index
]);
_buf_index
=
!
_buf_index
;
}
}
#endif
if
(
_param_size
>
rounded_size
)
Step_4
((
_params
+
rounded_size
),
(
grads
+
rounded_size
),
(
_exp_avg
+
rounded_size
),
(
_exp_avg_sq
+
rounded_size
),
(
_param_size
-
rounded_size
),
(
dev_params
!=
nullptr
?
(
dev_params
+
rounded_size
)
:
dev_params
));
}
int
ds_adam_step
(
int
optimizer_id
,
size_t
step
,
float
lr
,
float
beta1
,
float
beta2
,
float
epsilon
,
float
weight_decay
,
bool
bias_correction
,
torch
::
Tensor
&
params
,
torch
::
Tensor
&
grads
,
torch
::
Tensor
&
exp_avg
,
torch
::
Tensor
&
exp_avg_sq
)
{
auto
params_c
=
params
.
contiguous
();
auto
grads_c
=
grads
.
contiguous
();
auto
exp_avg_c
=
exp_avg
.
contiguous
();
auto
exp_avg_sq_c
=
exp_avg_sq
.
contiguous
();
float
*
params_ptr
=
(
float
*
)
params_c
.
data_ptr
();
float
*
grads_ptr
=
(
float
*
)
grads_c
.
data_ptr
();
float
*
exp_avg_ptr
=
(
float
*
)
exp_avg_c
.
data_ptr
();
float
*
exp_avg_sq_ptr
=
(
float
*
)
exp_avg_sq_c
.
data_ptr
();
std
::
shared_ptr
<
Adam_Optimizer
>
opt
=
std
::
static_pointer_cast
<
Adam_Optimizer
>
(
s_optimizers
[
optimizer_id
]);
opt
->
IncrementStep
(
step
,
beta1
,
beta2
);
opt
->
update_state
(
lr
,
epsilon
,
weight_decay
,
bias_correction
);
opt
->
Step_8
(
params_ptr
,
grads_ptr
,
exp_avg_ptr
,
exp_avg_sq_ptr
,
params_c
.
size
(
0
));
opt
->
SynchronizeStreams
();
return
0
;
}
int
ds_adam_step_plus_copy
(
int
optimizer_id
,
size_t
step
,
float
lr
,
float
beta1
,
float
beta2
,
float
epsilon
,
float
weight_decay
,
bool
bias_correction
,
torch
::
Tensor
&
params
,
torch
::
Tensor
&
grads
,
torch
::
Tensor
&
exp_avg
,
torch
::
Tensor
&
exp_avg_sq
,
torch
::
Tensor
&
gpu_params
)
{
auto
params_c
=
params
.
contiguous
();
auto
gpu_params_c
=
gpu_params
.
contiguous
();
auto
exp_avg_c
=
exp_avg
.
contiguous
();
auto
exp_avg_sq_c
=
exp_avg_sq
.
contiguous
();
auto
grads_c
=
grads
.
contiguous
();
float
*
params_ptr
=
(
float
*
)
params_c
.
data_ptr
();
float
*
grads_ptr
=
(
float
*
)
grads_c
.
data_ptr
();
__half
*
gpu_params_ptr
=
(
__half
*
)
gpu_params_c
.
data_ptr
();
float
*
exp_avg_ptr
=
(
float
*
)
exp_avg_c
.
data_ptr
();
float
*
exp_avg_sq_ptr
=
(
float
*
)
exp_avg_sq_c
.
data_ptr
();
std
::
shared_ptr
<
Adam_Optimizer
>
opt
=
std
::
static_pointer_cast
<
Adam_Optimizer
>
(
s_optimizers
[
optimizer_id
]);
opt
->
IncrementStep
(
step
,
beta1
,
beta2
);
opt
->
update_state
(
lr
,
epsilon
,
weight_decay
,
bias_correction
);
opt
->
Step_8
(
params_ptr
,
grads_ptr
,
exp_avg_ptr
,
exp_avg_sq_ptr
,
params_c
.
size
(
0
),
gpu_params_ptr
);
opt
->
SynchronizeStreams
();
return
0
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"adam_update"
,
&
ds_adam_step
,
"DeepSpeed CPU Adam update (C++)"
);
m
.
def
(
"adam_update_copy"
,
&
ds_adam_step_plus_copy
,
"DeepSpeed CPU Adam update and param copy (C++)"
);
m
.
def
(
"create_adam"
,
&
create_adam_optimizer
,
"DeepSpeed CPU Adam (C++)"
);
}
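The SIMD paths above vectorize the same per-element Adam/AdamW recurrence that the scalar fallback in this file computes. As a reference only (not part of the commit), a minimal scalar sketch of one fused update is shown below; the parameter names mirror the optimizer members used above, and the precomputed scalars (`step_size`, `bias_correction2`, `w_decay`) are assumed to be derived exactly as in `Adam_Optimizer::Step`.

// Minimal scalar sketch of the fused AdamW update done per element (illustrative only).
#include <cmath>
#include <cstddef>

void adam_scalar_step(float* param, const float* grad, float* exp_avg, float* exp_avg_sq,
                      size_t n, float beta1, float beta2, float eps,
                      float step_size /* assumed: -lr / bias_correction1 */,
                      float bias_correction2 /* assumed: 1 / sqrt(1 - beta2^t) */,
                      float w_decay /* assumed: -lr * weight_decay (AdamW mode) */)
{
    for (size_t k = 0; k < n; k++) {
        float m = beta1 * exp_avg[k] + (1 - beta1) * grad[k];          // first moment
        float v = beta2 * exp_avg_sq[k] + (1 - beta2) * grad[k] * grad[k];  // second moment
        float denom = sqrtf(v) * bias_correction2 + eps;               // bias-corrected denominator
        float p = param[k] + w_decay * param[k];                       // decoupled weight decay
        param[k] = p + step_size * (m / denom);                        // step_size is negative
        exp_avg[k] = m;
        exp_avg_sq[k] = v;
    }
}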
deepspeed/ops/csrc/adam/custom_cuda_kernel.cu
0 → 100755
View file @
eadbbe09
#include "custom_cuda_layers.h"
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
    int id = blockIdx.x * blockDim.x + threadIdx.x;

    if (id < size) { output[id] = (__half)input[id]; }
}

void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream)
{
    int threads = 1024;

    dim3 grid_dim((size - 1) / threads + 1);
    dim3 block_dim(threads);

    param_update_kernel<<<grid_dim, block_dim, 0, stream>>>(input, output, size);
}
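The kernel above down-converts the CPU-updated FP32 master parameters to FP16 on the GPU, one thread per element, on a caller-supplied stream. A hypothetical standalone driver (not part of this commit) might look like the following; the buffer names and the helper function are assumptions for illustration.

// Hypothetical driver for launch_param_update: stage an FP32 buffer on the device
// and convert it to FP16 asynchronously on the given stream.
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <vector>

void copy_and_halve(const std::vector<float>& host_params, __half* dev_fp16, cudaStream_t stream)
{
    float* dev_fp32 = nullptr;
    size_t bytes = host_params.size() * sizeof(float);
    cudaMalloc(&dev_fp32, bytes);
    cudaMemcpyAsync(dev_fp32, host_params.data(), bytes, cudaMemcpyHostToDevice, stream);
    // One thread per element; the kernel guards against reading past `size`.
    launch_param_update(dev_fp32, dev_fp16, (int)host_params.size(), stream);
    cudaStreamSynchronize(stream);
    cudaFree(dev_fp32);
}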
deepspeed/ops/csrc/adam/fused_adam_frontend.cpp
0 → 100644
View file @
eadbbe09
#include <torch/extension.h>
void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("multi_tensor_adam",
          &multi_tensor_adam_cuda,
          "Compute and apply gradient update to parameters for Adam optimizer");
}
deepspeed/ops/csrc/adam/hip/compat.h
0 → 100644
View file @
eadbbe09
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
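These shims let the same extension source compile against older and newer PyTorch releases: `TORCH_CHECK` falls back to `AT_CHECK`, and `DATA_PTR` expands to either `data_ptr` or the legacy `data` accessor. A small sketch of their intended use follows; the `VERSION_GE_1_3` define is assumed to be supplied by the extension build, and `read_flag` is a hypothetical helper, not part of this commit.

// Sketch: the same call compiles against old and new PyTorch thanks to the shims above.
#include <ATen/ATen.h>
#include "compat.h"

int read_flag(at::Tensor& flag)
{
    TORCH_CHECK(flag.numel() == 1, "expected a single-element tensor");
    return *flag.DATA_PTR<int>();  // expands to data_ptr<int>() or data<int>()
}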
deepspeed/ops/csrc/adam/hip/cpu_adam.cpp
0 → 100644
View file @
eadbbe09
#include "cpu_adam.h"
#include <hip/hip_runtime_api.h>
#include <math.h>
#include <omp.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
#include "rocblas.h"
#include "hip/hip_runtime.h"
#include "hiprand.h"
#include "custom_cuda_layers.h"
static std::unordered_map<int, std::shared_ptr<void>> s_optimizers;
#define ROUND_DOWN(size, step) ((size) & ~((step)-1))
// C++ interface
void Adam_Optimizer::Step(float* _params,
                          float* grads,
                          float* _exp_avg,
                          float* _exp_avg_sq,
                          size_t _param_size,
                          __half* dev_params)
{
    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;

    float step_size = -1 * _alpha / _bias_correction1;
    float w_decay = -1 * _alpha * _weight_decay;
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH);

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
        if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
        for (size_t i = t; i < offset; i += SIMD_WIDTH) {
            AVX_Data grad_4;
            grad_4.data = SIMD_LOAD(grads + i);

            AVX_Data momentum_4;
            momentum_4.data = SIMD_LOAD(_exp_avg + i);
            AVX_Data variance_4;
            variance_4.data = SIMD_LOAD(_exp_avg_sq + i);

            AVX_Data param_4;
            param_4.data = SIMD_LOAD(_params + i);

            if (_weight_decay > 0 && !_adamw_mode) {
                grad_4.data = SIMD_FMA(param_4.data, weight_decay4.data, grad_4.data);
            }

            momentum_4.data = SIMD_MUL(momentum_4.data, betta1_4.data);
            momentum_4.data = SIMD_FMA(grad_4.data, betta1_minus1_4.data, momentum_4.data);

            variance_4.data = SIMD_MUL(variance_4.data, betta2_4.data);
            grad_4.data = SIMD_MUL(grad_4.data, grad_4.data);
            variance_4.data = SIMD_FMA(grad_4.data, betta2_minus1_4.data, variance_4.data);

            grad_4.data = SIMD_SQRT(variance_4.data);
            grad_4.data = SIMD_FMA(grad_4.data, bias2_sqrt.data, eps_4.data);
            grad_4.data = SIMD_DIV(momentum_4.data, grad_4.data);

            if (_weight_decay > 0 && _adamw_mode) {
                param_4.data = SIMD_FMA(param_4.data, weight_decay4.data, param_4.data);
            }

            param_4.data = SIMD_FMA(grad_4.data, step_size_4.data, param_4.data);

            SIMD_STORE(_params + i, param_4.data);

            if (dev_params) SIMD_STORE(_doubled_buffer[_buf_index] + (i - t), param_4.data);

            SIMD_STORE(_exp_avg + i, momentum_4.data);
            SIMD_STORE(_exp_avg_sq + i, variance_4.data);
        }
        if (dev_params) {
            launch_param_update(
                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            _buf_index = !_buf_index;
        }
    }

#endif

    if (_param_size > rounded_size) {
        for (size_t t = rounded_size; t < _param_size; t += TILE) {
            size_t copy_size = TILE;
            if ((t + TILE) > _param_size) copy_size = _param_size - t;
            size_t offset = copy_size + t;
            if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
            for (size_t k = t; k < offset; k++) {
                float grad = grads[k];
                float param = _params[k];
                float momentum = _exp_avg[k];
                float variance = _exp_avg_sq[k];
                if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; }

                momentum = momentum * _betta1;
                momentum = grad * betta1_minus1 + momentum;

                variance = variance * _betta2;
                grad = grad * grad;
                variance = grad * betta2_minus1 + variance;

                grad = sqrt(variance);
                grad = grad * _bias_correction2 + _eps;
                grad = momentum / grad;
                if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
                param = grad * step_size + param;

                if (dev_params) _doubled_buffer[_buf_index][k - t] = param;

                _params[k] = param;
                _exp_avg[k] = momentum;
                _exp_avg_sq[k] = variance;
            }
            if (dev_params) {
                launch_param_update(_doubled_buffer[_buf_index],
                                    dev_params + t,
                                    (copy_size),
                                    _streams[_buf_index]);
                _buf_index = !_buf_index;
            }
        }
    }
}

void Adam_Optimizer::Step_4(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            __half* dev_params)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 2));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
        if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) {
            AVX_Data grad_4[4];
            AVX_Data momentum_4[4];
            AVX_Data variance_4[4];
            AVX_Data param_4[4];
            for (int j = 0; j < 4; j++) {
                grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
                momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
                variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
                param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
            }

            if (_weight_decay > 0 && !_adamw_mode) {
                for (int j = 0; j < 4; j++)
                    grad_4[j].data = SIMD_FMA(param_4[j].data, weight_decay4.data, grad_4[j].data);
            }

            for (int j = 0; j < 4; j++) {
                momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
                momentum_4[j].data =
                    SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
                variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
                grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
                variance_4[j].data =
                    SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
                grad_4[j].data = SIMD_SQRT(variance_4[j].data);
                grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
                grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
            }

            if (_weight_decay > 0 && _adamw_mode) {
                for (int j = 0; j < 4; j++)
                    param_4[j].data =
                        SIMD_FMA(param_4[j].data, weight_decay4.data, param_4[j].data);
            }

            for (int j = 0; j < 4; j++) {
                param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
                SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
                if (dev_params)
                    SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * j,
                               param_4[j].data);
                SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
                SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
            }
        }
        if (dev_params) {
            launch_param_update(
                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            _buf_index = !_buf_index;
        }
    }
#endif
    if (_param_size > rounded_size)
        Step((_params + rounded_size),
             (grads + rounded_size),
             (_exp_avg + rounded_size),
             (_exp_avg_sq + rounded_size),
             (_param_size - rounded_size),
             (dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}

int create_adam_optimizer(int optimizer_id,
                          float alpha = 1e-3,
                          float betta1 = 0.9,
                          float betta2 = 0.999,
                          float eps = 1e-8,
                          float weight_decay = 0,
                          bool adamw_mode = true)
{
    auto opt =
        std::make_shared<Adam_Optimizer>(alpha, betta1, betta2, eps, weight_decay, adamw_mode);

    s_optimizers[optimizer_id] = opt;

#if defined(__AVX512__)
    std::cout << "Adam Optimizer #" << optimizer_id
              << " is created with AVX512 arithmetic capability." << std::endl;
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha,
           betta1,
           betta2,
           weight_decay,
           (int)adamw_mode);
#else
#if defined(__AVX256__)
    std::cout << "Adam Optimizer #" << optimizer_id
              << " is created with AVX2 arithmetic capability." << std::endl;
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha,
           betta1,
           betta2,
           weight_decay,
           (int)adamw_mode);
#else
    std::cout << "Adam Optimizer #" << optimizer_id
              << " is created with scalar arithmetic capability." << std::endl;
    printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n",
           alpha,
           betta1,
           betta2,
           weight_decay,
           (int)adamw_mode);
#endif
#endif
    return 0;
}

void Adam_Optimizer::Step_8(float* _params,
                            float* grads,
                            float* _exp_avg,
                            float* _exp_avg_sq,
                            size_t _param_size,
                            __half* dev_params)
{
    size_t rounded_size = 0;

#if defined(__AVX512__) or defined(__AVX256__)

    AVX_Data betta1_4;
    betta1_4.data = SIMD_SET(_betta1);
    AVX_Data betta2_4;
    betta2_4.data = SIMD_SET(_betta2);

    float betta1_minus1 = 1 - _betta1;
    float betta2_minus1 = 1 - _betta2;
    AVX_Data betta1_minus1_4;
    betta1_minus1_4.data = SIMD_SET(betta1_minus1);
    AVX_Data betta2_minus1_4;
    betta2_minus1_4.data = SIMD_SET(betta2_minus1);

    AVX_Data bias2_sqrt;
    bias2_sqrt.data = SIMD_SET(_bias_correction2);

    AVX_Data eps_4;
    eps_4.data = SIMD_SET(_eps);

    float step_size = -1 * _alpha / _bias_correction1;
    AVX_Data step_size_4;
    step_size_4.data = SIMD_SET(step_size);

    float w_decay = -1 * _alpha * _weight_decay;
    AVX_Data weight_decay4;
    if (_weight_decay > 0)
        weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay));
    rounded_size = ROUND_DOWN(_param_size, (SIMD_WIDTH << 3));

    for (size_t t = 0; t < rounded_size; t += TILE) {
        size_t copy_size = TILE;
        if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
        size_t offset = copy_size + t;
        if ((t / TILE) >= 2) { hipStreamSynchronize(_streams[_buf_index]); }
#pragma omp parallel for
        for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) {
            AVX_Data grad_4[8];
            AVX_Data momentum_4[8];
            AVX_Data variance_4[8];
            AVX_Data param_4[8];
            for (int j = 0; j < 8; j++) {
                grad_4[j].data = SIMD_LOAD(grads + i + SIMD_WIDTH * j);
                momentum_4[j].data = SIMD_LOAD(_exp_avg + i + SIMD_WIDTH * j);
                variance_4[j].data = SIMD_LOAD(_exp_avg_sq + i + SIMD_WIDTH * j);
                param_4[j].data = SIMD_LOAD(_params + i + SIMD_WIDTH * j);
            }

            if (_weight_decay > 0 && !_adamw_mode) {
                for (int j = 0; j < 8; j++)
                    grad_4[j].data = SIMD_FMA(param_4[j].data, weight_decay4.data, grad_4[j].data);
            }

            for (int j = 0; j < 8; j++) {
                momentum_4[j].data = SIMD_MUL(momentum_4[j].data, betta1_4.data);
                momentum_4[j].data =
                    SIMD_FMA(grad_4[j].data, betta1_minus1_4.data, momentum_4[j].data);
                variance_4[j].data = SIMD_MUL(variance_4[j].data, betta2_4.data);
                grad_4[j].data = SIMD_MUL(grad_4[j].data, grad_4[j].data);
                variance_4[j].data =
                    SIMD_FMA(grad_4[j].data, betta2_minus1_4.data, variance_4[j].data);
                grad_4[j].data = SIMD_SQRT(variance_4[j].data);
                grad_4[j].data = SIMD_FMA(grad_4[j].data, bias2_sqrt.data, eps_4.data);
                grad_4[j].data = SIMD_DIV(momentum_4[j].data, grad_4[j].data);
            }

            if (_weight_decay > 0 && _adamw_mode) {
                for (int j = 0; j < 8; j++)
                    param_4[j].data =
                        SIMD_FMA(param_4[j].data, weight_decay4.data, param_4[j].data);
            }

            for (int j = 0; j < 8; j++) {
                param_4[j].data = SIMD_FMA(grad_4[j].data, step_size_4.data, param_4[j].data);
                SIMD_STORE(_params + i + SIMD_WIDTH * j, param_4[j].data);
                if (dev_params)
                    SIMD_STORE(_doubled_buffer[_buf_index] + (i - t) + SIMD_WIDTH * j,
                               param_4[j].data);
                SIMD_STORE(_exp_avg + i + SIMD_WIDTH * j, momentum_4[j].data);
                SIMD_STORE(_exp_avg_sq + i + SIMD_WIDTH * j, variance_4[j].data);
            }
        }
        if (dev_params) {
            launch_param_update(
                _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]);
            _buf_index = !_buf_index;
        }
    }
#endif
    if (_param_size > rounded_size)
        Step_4((_params + rounded_size),
               (grads + rounded_size),
               (_exp_avg + rounded_size),
               (_exp_avg_sq + rounded_size),
               (_param_size - rounded_size),
               (dev_params != nullptr ? (dev_params + rounded_size) : dev_params));
}

int ds_adam_step(int optimizer_id,
                 size_t step,
                 float lr,
                 float beta1,
                 float beta2,
                 float epsilon,
                 float weight_decay,
                 bool bias_correction,
                 torch::Tensor& params,
                 torch::Tensor& grads,
                 torch::Tensor& exp_avg,
                 torch::Tensor& exp_avg_sq)
{
    auto params_c = params.contiguous();
    auto grads_c = grads.contiguous();
    auto exp_avg_c = exp_avg.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adam_Optimizer> opt =
        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, epsilon, weight_decay, bias_correction);
    opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
    opt->SynchronizeStreams();
    return 0;
}

int ds_adam_step_plus_copy(int optimizer_id,
                           size_t step,
                           float lr,
                           float beta1,
                           float beta2,
                           float epsilon,
                           float weight_decay,
                           bool bias_correction,
                           torch::Tensor& params,
                           torch::Tensor& grads,
                           torch::Tensor& exp_avg,
                           torch::Tensor& exp_avg_sq,
                           torch::Tensor& gpu_params)
{
    auto params_c = params.contiguous();
    auto gpu_params_c = gpu_params.contiguous();
    auto exp_avg_c = exp_avg.contiguous();
    auto exp_avg_sq_c = exp_avg_sq.contiguous();
    auto grads_c = grads.contiguous();

    float* params_ptr = (float*)params_c.data_ptr();
    float* grads_ptr = (float*)grads_c.data_ptr();
    __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr();
    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();

    std::shared_ptr<Adam_Optimizer> opt =
        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
    opt->IncrementStep(step, beta1, beta2);
    opt->update_state(lr, epsilon, weight_decay, bias_correction);
    opt->Step_8(params_ptr,
                grads_ptr,
                exp_avg_ptr,
                exp_avg_sq_ptr,
                params_c.size(0),
                gpu_params_ptr);
    opt->SynchronizeStreams();
    return 0;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)");
    m.def("adam_update_copy",
          &ds_adam_step_plus_copy,
          "DeepSpeed CPU Adam update and param copy (C++)");
    m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)");
}
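The TILE loops in Step/Step_4/Step_8 overlap CPU work with the FP16 copy to the device: two host staging buffers (`_doubled_buffer[0/1]`) alternate with two streams, and the `(t / TILE) >= 2` check waits only when a buffer is about to be reused, i.e. two tiles later. A stripped-down sketch of that ping-pong pattern follows; `pingpong_copy`, its arguments, and the stand-in CPU loop are assumptions for illustration, not code from this commit.

// Sketch of the double-buffered overlap used above: compute tile t+1 on the CPU
// while the async copy of tile t is still in flight on the other stream.
void pingpong_copy(float* buf[2], __half* dst, const float* src,
                   size_t total, size_t tile, hipStream_t streams[2])
{
    int idx = 0;
    for (size_t t = 0; t < total; t += tile) {
        size_t n = (t + tile > total) ? total - t : tile;
        // Before overwriting a buffer (every second tile), wait for its previous copy.
        if ((t / tile) >= 2) hipStreamSynchronize(streams[idx]);
        for (size_t k = 0; k < n; k++) buf[idx][k] = src[t + k];  // stands in for the CPU Adam work
        launch_param_update(buf[idx], dst + t, (int)n, streams[idx]);
        idx = !idx;
    }
    hipStreamSynchronize(streams[0]);
    hipStreamSynchronize(streams[1]);
}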
deepspeed/ops/csrc/adam/hip/custom_hip_kernel.hip
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
#include "custom_cuda_layers.h"
__global__ void param_update_kernel(const float* input, __half* output, int size)
{
int id = blockIdx.x * blockDim.x + threadIdx.x;
if (id < size) { output[id] = (__half)input[id]; }
}
void launch_param_update(const float* input, __half* output, int size, hipStream_t stream)
{
int threads = 1024;
dim3 grid_dim((size - 1) / threads + 1);
dim3 block_dim(threads);
hipLaunchKernelGGL(( param_update_kernel), dim3(grid_dim), dim3(block_dim), 0, stream, input, output, size);
}
deepspeed/ops/csrc/adam/hip/fused_adam_frontend.cpp
0 → 100644
View file @
eadbbe09
#include <torch/extension.h>
void
multi_tensor_adam_cuda
(
int
chunk_size
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
,
const
float
lr
,
const
float
beta1
,
const
float
beta2
,
const
float
epsilon
,
const
int
step
,
const
int
mode
,
const
int
bias_correction
,
const
float
weight_decay
);
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"multi_tensor_adam"
,
&
multi_tensor_adam_cuda
,
"Compute and apply gradient update to parameters for Adam optimizer"
);
}
deepspeed/ops/csrc/adam/hip/multi_tensor_adam.hip
0 → 100644
View file @
eadbbe09
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
ADAM_MODE_0 = 0, // L2 regularization mode
ADAM_MODE_1 = 1 // Decoupled weight decay mode(AdamW)
} adamMode_t;
using MATH_T = float;
template <typename T>
struct AdamFunctor {
__device__ __forceinline__ void operator()(int chunk_size,
volatile int* noop_gmem,
TensorListMetadata<4>& tl,
const float beta1,
const float beta2,
const float beta1_correction,
const float beta2_correction,
const float epsilon,
const float lr,
adamMode_t mode,
const float decay)
{
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
int tensor_loc = tl.block_to_tensor[blockIdx.x];
// potentially use to pass in list of scalar
// int tensor_num = tl.start_tensor_this_launch + tensor_loc;
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
T* g = (T*)tl.addresses[0][tensor_loc];
g += chunk_idx * chunk_size;
T* p = (T*)tl.addresses[1][tensor_loc];
p += chunk_idx * chunk_size;
T* m = (T*)tl.addresses[2][tensor_loc];
m += chunk_idx * chunk_size;
T* v = (T*)tl.addresses[3][tensor_loc];
v += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
// see note in multi_tensor_scale_kernel.cu
for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
MATH_T r_g[ILP];
MATH_T r_p[ILP];
MATH_T r_m[ILP];
MATH_T r_v[ILP];
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
r_g[ii] = g[i];
r_p[ii] = p[i];
r_m[ii] = m[i];
r_v[ii] = v[i];
} else {
r_g[ii] = MATH_T(0);
r_p[ii] = MATH_T(0);
r_m[ii] = MATH_T(0);
r_v[ii] = MATH_T(0);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
if (mode == ADAM_MODE_0) { // L2
r_g[ii] = r_g[ii] + (decay * r_p[ii]);
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = next_m_unbiased / denom;
r_p[ii] = r_p[ii] - (lr * update);
} else { // weight decay
r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
r_p[ii] = r_p[ii] - (lr * update);
}
}
#pragma unroll
for (int ii = 0; ii < ILP; ii++) {
int i = i_start + threadIdx.x + ii * blockDim.x;
if (i < n && i < chunk_size) {
p[i] = r_p[ii];
m[i] = r_m[ii];
v[i] = r_v[ii];
}
}
}
}
};
void multi_tensor_adam_cuda(int chunk_size,
at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
const float lr,
const float beta1,
const float beta2,
const float epsilon,
const int step,
const int mode,
const int bias_correction,
const float weight_decay)
{
using namespace at;
// Handle bias correction mode
float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
if (bias_correction == 1) {
bias_correction1 = 1 - ::pow(beta1, step);
bias_correction2 = 1 - ::pow(beta2, step);
}
// Assume single type across p,g,m1,m2 now
DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
0,
"adam",
multi_tensor_apply<4>(BLOCK_SIZE,
chunk_size,
noop_flag,
tensor_lists,
AdamFunctor<scalar_t_0>(),
beta1,
beta2,
bias_correction1,
bias_correction2,
epsilon,
lr,
(adamMode_t)mode,
weight_decay);)
AT_CUDA_CHECK(hipGetLastError());
}
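For reference, with bias_correction enabled the functor applies the standard bias-corrected Adam step; in ADAM_MODE_1 (AdamW) the decay term is added to the update rather than folded into the gradient:

\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2,\\
\hat m_t &= m_t / (1-\beta_1^t), \qquad \hat v_t = v_t / (1-\beta_2^t),\\
\theta_t &= \theta_{t-1} - \mathrm{lr}\left(\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_{t-1}\right) \quad (\text{ADAM\_MODE\_1}).
\end{aligned}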
deepspeed/ops/csrc/adam/hip/multi_tensor_apply.cuh
0 → 100644
View file @
eadbbe09
#include "hip/hip_runtime.h"
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/hip/HIPContext.h>
#include <ATen/hip/Exceptions.h>
#include <ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h>
#include "compat.h"
#include <assert.h>
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template <int n>
struct TensorListMetadata {
    void* addresses[n][depth_to_max_tensors[n - 1]];
    int sizes[depth_to_max_tensors[n - 1]];
    unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
    int block_to_chunk[depth_to_max_blocks[n - 1]];
    // I fear this needs to be a full int.
    int start_tensor_this_launch;
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
                                          volatile int* noop_flag,
                                          T tl,
                                          U callable,
                                          ArgTypes... args)
{
    // Hand the chunk information to the user-supplied functor to process however it likes.
    callable(chunk_size, noop_flag, tl, args...);
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
                        int chunk_size,
                        const at::Tensor& noop_flag,
                        const std::vector<std::vector<at::Tensor>>& tensor_lists,
                        T callable,
                        ArgTypes... args)
{
    TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
    int len0 = tensor_lists[0].size();
    TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
    auto ref_device = tensor_lists[0][0].device();
    TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
    for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
    {
        TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
        for (int t = 0; t < tensor_lists[l].size(); t++) {
            // TODO: Print which tensor fails.
            bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
            contiguous_memory = (contiguous_memory ||
                                 tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
            TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
            TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
                        "A tensor was not on the same device as the first tensor");
            TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
        }
    }

    int ntensors = tensor_lists[0].size();

    TensorListMetadata<depth> tl;

    const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(tensor_lists[0][0]));
    auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();

    tl.start_tensor_this_launch = 0;
    int loc_block_info = 0;
    int loc_tensor_info = 0;
    for (int t = 0; t < ntensors; t++) {
        tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
        for (int d = 0; d < depth; d++)
            tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
        loc_tensor_info++;

        int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

        for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
            // std::cout << chunks_this_tensor << std::endl;
            tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
            tl.block_to_chunk[loc_block_info] = chunk;
            loc_block_info++;

            bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                                 chunk == chunks_this_tensor - 1);
            bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
            bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
            if (tensors_full || blocks_full || last_chunk) {
                // using accscalar_t = acc_type<scalar_t, true>;
                hipLaunchKernelGGL((multi_tensor_apply_kernel),
                                   dim3(loc_block_info),
                                   dim3(block_size),
                                   0,
                                   stream,
                                   chunk_size,
                                   noop_flag.DATA_PTR<int>(),
                                   tl,
                                   callable,
                                   args...);

                AT_CUDA_CHECK(hipGetLastError());

                // Reset. The control flow possibilities here make my brain hurt.
                loc_block_info = 0;
                if (chunk == chunks_this_tensor - 1) {
                    // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 <<
                    // std::endl;
                    loc_tensor_info = 0;
                    tl.start_tensor_this_launch = t + 1;
                } else {
                    // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 <<
                    // std::endl;
                    tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
                    for (int d = 0; d < depth; d++)
                        tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
                    loc_tensor_info = 1;
                    tl.start_tensor_this_launch = t;
                }
            }
        }
    }
}
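The host side flattens the whole tensor list into fixed-size chunks, records a (tensor, chunk) pair per thread block in TensorListMetadata, and flushes a kernel launch whenever either the tensor slots or the block slots fill up, or the last chunk is reached. A minimal, ATen-free sketch of just that chunk bookkeeping is shown below; the tensor sizes are made up for illustration.

// Sketch of the block -> (tensor, chunk) mapping built before each launch above.
#include <cstdio>
#include <vector>

int main()
{
    const int chunk_size = 4;
    std::vector<int> numels = {10, 3, 7};  // hypothetical tensor sizes
    int block = 0;
    for (size_t t = 0; t < numels.size(); t++) {
        int chunks = (numels[t] + chunk_size - 1) / chunk_size;  // ceil-divide into chunks
        for (int c = 0; c < chunks; c++)
            printf("block %d -> tensor %zu, chunk %d\n", block++, t, c);
    }
    return 0;
}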
deepspeed/ops/csrc/adam/multi_tensor_adam.cu
0 → 100644
View file @
eadbbe09
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>
#include <assert.h>
#include "multi_tensor_apply.cuh"
#include "type_shim.h"
#define BLOCK_SIZE 512
#define ILP 4
typedef enum {
    ADAM_MODE_0 = 0,  // L2 regularization mode
    ADAM_MODE_1 = 1   // Decoupled weight decay mode (AdamW)
} adamMode_t;

using MATH_T = float;

template <typename T>
struct AdamFunctor {
    __device__ __forceinline__ void operator()(int chunk_size,
                                               volatile int* noop_gmem,
                                               TensorListMetadata<4>& tl,
                                               const float beta1,
                                               const float beta2,
                                               const float beta1_correction,
                                               const float beta2_correction,
                                               const float epsilon,
                                               const float lr,
                                               adamMode_t mode,
                                               const float decay)
    {
        // I'd like this kernel to propagate infs/nans.
        // if(*noop_gmem == 1)
        //   return;

        int tensor_loc = tl.block_to_tensor[blockIdx.x];

        // potentially use to pass in list of scalar
        // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

        int chunk_idx = tl.block_to_chunk[blockIdx.x];
        int n = tl.sizes[tensor_loc];

        T* g = (T*)tl.addresses[0][tensor_loc];
        g += chunk_idx * chunk_size;

        T* p = (T*)tl.addresses[1][tensor_loc];
        p += chunk_idx * chunk_size;

        T* m = (T*)tl.addresses[2][tensor_loc];
        m += chunk_idx * chunk_size;

        T* v = (T*)tl.addresses[3][tensor_loc];
        v += chunk_idx * chunk_size;

        n -= chunk_idx * chunk_size;

        // see note in multi_tensor_scale_kernel.cu
        for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * ILP) {
            MATH_T r_g[ILP];
            MATH_T r_p[ILP];
            MATH_T r_m[ILP];
            MATH_T r_v[ILP];
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    r_g[ii] = g[i];
                    r_p[ii] = p[i];
                    r_m[ii] = m[i];
                    r_v[ii] = v[i];
                } else {
                    r_g[ii] = MATH_T(0);
                    r_p[ii] = MATH_T(0);
                    r_m[ii] = MATH_T(0);
                    r_v[ii] = MATH_T(0);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                if (mode == ADAM_MODE_0) {  // L2
                    r_g[ii] = r_g[ii] + (decay * r_p[ii]);
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = next_m_unbiased / denom;
                    r_p[ii] = r_p[ii] - (lr * update);
                } else {  // weight decay
                    r_m[ii] = beta1 * r_m[ii] + (1 - beta1) * r_g[ii];
                    r_v[ii] = beta2 * r_v[ii] + (1 - beta2) * r_g[ii] * r_g[ii];
                    MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
                    MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
                    MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
                    MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
                    r_p[ii] = r_p[ii] - (lr * update);
                }
            }
#pragma unroll
            for (int ii = 0; ii < ILP; ii++) {
                int i = i_start + threadIdx.x + ii * blockDim.x;
                if (i < n && i < chunk_size) {
                    p[i] = r_p[ii];
                    m[i] = r_m[ii];
                    v[i] = r_v[ii];
                }
            }
        }
    }
};

void multi_tensor_adam_cuda(int chunk_size,
                            at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists,
                            const float lr,
                            const float beta1,
                            const float beta2,
                            const float epsilon,
                            const int step,
                            const int mode,
                            const int bias_correction,
                            const float weight_decay)
{
    using namespace at;

    // Handle bias correction mode
    float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
    if (bias_correction == 1) {
        bias_correction1 = 1 - std::pow(beta1, step);
        bias_correction2 = 1 - std::pow(beta2, step);
    }

    // Assume single type across p,g,m1,m2 now
    DISPATCH_DOUBLE_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(),
                                   0,
                                   "adam",
                                   multi_tensor_apply<4>(BLOCK_SIZE,
                                                         chunk_size,
                                                         noop_flag,
                                                         tensor_lists,
                                                         AdamFunctor<scalar_t_0>(),
                                                         beta1,
                                                         beta2,
                                                         bias_correction1,
                                                         bias_correction2,
                                                         epsilon,
                                                         lr,
                                                         (adamMode_t)mode,
                                                         weight_decay);)

    AT_CUDA_CHECK(cudaGetLastError());
}
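The AdamFunctor above is the standard Adam / AdamW update vectorized over ILP elements per thread: ADAM_MODE_0 folds L2 regularization into the gradient, while ADAM_MODE_1 applies decoupled weight decay to the parameter. A scalar reference of the same arithmetic, useful when checking the fused kernel's numerics, might look like the following plain C++ sketch (hypothetical helper, assumes bias correction is enabled; not part of this commit):

#include <cmath>
#include <vector>

// Scalar reference of the per-element update performed by AdamFunctor above.
// Hypothetical helper for illustration; not part of the DeepSpeed sources.
void adam_reference(std::vector<float>& p,
                    const std::vector<float>& g,
                    std::vector<float>& m,
                    std::vector<float>& v,
                    float lr, float beta1, float beta2, float epsilon,
                    int step, bool adamw, float decay)
{
    float bc1 = 1.0f - std::pow(beta1, step);  // bias_correction1
    float bc2 = 1.0f - std::pow(beta2, step);  // bias_correction2
    for (size_t i = 0; i < p.size(); ++i) {
        float grad = adamw ? g[i] : g[i] + decay * p[i];  // ADAM_MODE_0 folds L2 into the gradient
        m[i] = beta1 * m[i] + (1.0f - beta1) * grad;
        v[i] = beta2 * v[i] + (1.0f - beta2) * grad * grad;
        float m_hat = m[i] / bc1;
        float v_hat = v[i] / bc2;
        float update = m_hat / (std::sqrt(v_hat) + epsilon);
        if (adamw) update += decay * p[i];  // ADAM_MODE_1: decoupled weight decay
        p[i] -= lr * update;
    }
}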
deepspeed/ops/csrc/adam/multi_tensor_apply.cuh
0 → 100644
View file @
eadbbe09
/* Copyright 2020 The Microsoft DeepSpeed Team
Copyright NVIDIA/apex
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
*/
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include "compat.h"
#include <assert.h>
// #include <iostream>
// This header is the one-stop shop for all your multi-tensor apply needs.
// TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson)
constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template <int n>
struct TensorListMetadata {
    void* addresses[n][depth_to_max_tensors[n - 1]];
    int sizes[depth_to_max_tensors[n - 1]];
    unsigned char block_to_tensor[depth_to_max_blocks[n - 1]];
    int block_to_chunk[depth_to_max_blocks[n - 1]];
    // I fear this needs to be a full int.
    int start_tensor_this_launch;
};

template <typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(int chunk_size,
                                          volatile int* noop_flag,
                                          T tl,
                                          U callable,
                                          ArgTypes... args)
{
    // Hand the chunk information to the user-supplied functor to process however it likes.
    callable(chunk_size, noop_flag, tl, args...);
}

template <int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(int block_size,
                        int chunk_size,
                        const at::Tensor& noop_flag,
                        const std::vector<std::vector<at::Tensor>>& tensor_lists,
                        T callable,
                        ArgTypes... args)
{
    TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
    int len0 = tensor_lists[0].size();
    TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
    auto ref_device = tensor_lists[0][0].device();
    TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
    for (int l = 0; l < tensor_lists.size(); l++)  // No range-based for because I need indices
    {
        TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
        for (int t = 0; t < tensor_lists[l].size(); t++) {
            // TODO: Print which tensor fails.
            bool contiguous_memory = tensor_lists[l][t].is_contiguous();
#ifdef VERSION_GE_1_5
            contiguous_memory =
                (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
#endif
            TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
            TORCH_CHECK(tensor_lists[l][t].device() == ref_device,
                        "A tensor was not on the same device as the first tensor");
            TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
        }
    }

    int ntensors = tensor_lists[0].size();

    TensorListMetadata<depth> tl;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
    auto stream = at::cuda::getCurrentCUDAStream();

    tl.start_tensor_this_launch = 0;
    int loc_block_info = 0;
    int loc_tensor_info = 0;
    for (int t = 0; t < ntensors; t++) {
        tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
        for (int d = 0; d < depth; d++)
            tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
        loc_tensor_info++;

        int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1) / chunk_size;

        for (int chunk = 0; chunk < chunks_this_tensor; chunk++) {
            // std::cout << chunks_this_tensor << std::endl;
            tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
            tl.block_to_chunk[loc_block_info] = chunk;
            loc_block_info++;

            bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] &&
                                 chunk == chunks_this_tensor - 1);
            bool blocks_full = (loc_block_info == depth_to_max_blocks[depth - 1]);
            bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
            if (tensors_full || blocks_full || last_chunk) {
                // using accscalar_t = acc_type<scalar_t, true>;
                multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
                    chunk_size, noop_flag.DATA_PTR<int>(), tl, callable, args...);

                AT_CUDA_CHECK(cudaGetLastError());

                // Reset. The control flow possibilities here make my brain hurt.
                loc_block_info = 0;
                if (chunk == chunks_this_tensor - 1) {
                    // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
                    loc_tensor_info = 0;
                    tl.start_tensor_this_launch = t + 1;
                } else {
                    // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
                    tl.sizes[0] = tl.sizes[loc_tensor_info - 1];
                    for (int d = 0; d < depth; d++)
                        tl.addresses[d][0] = tl.addresses[d][loc_tensor_info - 1];
                    loc_tensor_info = 1;
                    tl.start_tensor_this_launch = t;
                }
            }
        }
    }
}
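multi_tensor_apply packs tensor pointers and chunk assignments into a fixed-size TensorListMetadata struct passed by value as a kernel argument, and flushes a launch whenever the tensor slots fill, the block slots fill, or the last chunk of the last tensor is reached. A host-only sketch of that batching arithmetic for depth 4 (hypothetical illustration, no GPU code, not part of this commit) shows how many launches a given set of tensor sizes produces:

#include <cstdio>
#include <vector>

// Host-only sketch of the batching logic in multi_tensor_apply above (depth = 4).
// Hypothetical illustration; not part of the DeepSpeed sources.
int count_launches(const std::vector<int>& numels, int chunk_size)
{
    const int max_tensors = 36;  // depth_to_max_tensors[4 - 1]
    const int max_blocks = 320;  // depth_to_max_blocks[4 - 1]
    int launches = 0, loc_tensor_info = 0, loc_block_info = 0;
    for (size_t t = 0; t < numels.size(); ++t) {
        ++loc_tensor_info;
        int chunks = (numels[t] + chunk_size - 1) / chunk_size;
        for (int chunk = 0; chunk < chunks; ++chunk) {
            ++loc_block_info;
            bool tensors_full = (loc_tensor_info == max_tensors && chunk == chunks - 1);
            bool blocks_full = (loc_block_info == max_blocks);
            bool last_chunk = (t == numels.size() - 1 && chunk == chunks - 1);
            if (tensors_full || blocks_full || last_chunk) {
                ++launches;
                loc_block_info = 0;
                loc_tensor_info = (chunk == chunks - 1) ? 0 : 1;  // carry the split tensor over
            }
        }
    }
    return launches;
}

int main()
{
    std::vector<int> numels = {1 << 20, 1 << 18, 4096};  // three parameter tensors
    printf("kernel launches: %d\n", count_launches(numels, 65536));
    return 0;
}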
deepspeed/ops/csrc/includes/StopWatch.h
0 → 100644
View file @
eadbbe09
#pragma once
#ifdef _WIN32
#include <windows.h>
#else
#include <time.h>
#endif
#ifdef _WIN32
class Stopwatch {
private:
    double m_total_time;
    LARGE_INTEGER m_start_time;

public:
    Stopwatch() { m_total_time = 0.0; }

    ~Stopwatch() {}

    void Reset() { m_total_time = 0.0; }

    void Start() { QueryPerformanceCounter(&m_start_time); }

    void Restart()
    {
        m_total_time = 0.0;
        QueryPerformanceCounter(&m_start_time);
    }

    void Stop()
    {
        LARGE_INTEGER frequency;
        LARGE_INTEGER stop_time;
        QueryPerformanceFrequency(&frequency);
        QueryPerformanceCounter(&stop_time);
        m_total_time +=
            ((double)(stop_time.QuadPart - m_start_time.QuadPart) / (double)frequency.QuadPart);
    }

    double GetTimeInSeconds() { return m_total_time; }
};
#else
class Stopwatch {
private:
    double m_total_time;
    struct timespec m_start_time;
    bool m_is_started;

public:
    Stopwatch()
    {
        m_total_time = 0.0;
        m_is_started = false;
    }

    ~Stopwatch() {}

    void Reset() { m_total_time = 0.0; }

    void Start()
    {
        clock_gettime(CLOCK_MONOTONIC, &m_start_time);
        m_is_started = true;
    }

    void Restart()
    {
        m_total_time = 0.0;
        clock_gettime(CLOCK_MONOTONIC, &m_start_time);
        m_is_started = true;
    }

    void Stop()
    {
        if (m_is_started) {
            m_is_started = false;
            struct timespec end_time;
            clock_gettime(CLOCK_MONOTONIC, &end_time);
            m_total_time += (double)(end_time.tv_sec - m_start_time.tv_sec) +
                            (double)(end_time.tv_nsec - m_start_time.tv_nsec) / 1e9;
        }
    }

    double GetTimeInSeconds()
    {
        if (m_is_started) {
            Stop();
            Start();
        }
        return m_total_time;
    }
};
#endif
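A minimal usage sketch of the Stopwatch class above (hypothetical standalone program, assuming the header is available as "StopWatch.h"; not part of this commit):

#include <cstdio>
#include "StopWatch.h"

int main()
{
    Stopwatch timer;
    timer.Start();
    double acc = 0.0;
    for (int i = 0; i < 1000000; ++i) acc += i * 0.5;  // work to be timed
    timer.Stop();
    printf("accumulated %.1f in %.6f s\n", acc, timer.GetTimeInSeconds());
    return 0;
}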
deepspeed/ops/csrc/includes/Timer.h
0 → 100644
View file @
eadbbe09
#ifndef __TIMER_H__
#define __TIMER_H__
#include <cuda_runtime.h>
#include <chrono>
#include "cuda.h"
class GPUTimer {
    cudaEvent_t start, stop;

public:
    GPUTimer()
    {
        cudaEventCreate(&start);
        cudaEventCreate(&stop);
    }
    ~GPUTimer()
    {
        cudaEventDestroy(start);
        cudaEventDestroy(stop);
    }
    inline void Record() { cudaEventRecord(start); }
    inline void Elapsed(float& time_elapsed)
    {
        cudaEventRecord(stop);
        cudaEventSynchronize(stop);
        cudaEventElapsedTime(&time_elapsed, start, stop);
    }
};

class CPUTimer {
    std::chrono::high_resolution_clock::time_point start;

public:
    CPUTimer() : start(std::chrono::high_resolution_clock::now()) {}
    inline void Reset() { start = std::chrono::high_resolution_clock::now(); }
    inline float Elapsed()
    {
        auto temp = start;
        start = std::chrono::high_resolution_clock::now();
        return (float)(std::chrono::duration_cast<std::chrono::microseconds>(start - temp).count() /
                       1e3);
    }
};
#endif
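A possible usage sketch of the two timers above (hypothetical program, assumes a CUDA-capable device and that the header is available as "Timer.h"; not part of this commit):

#include <cstdio>
#include "Timer.h"

// Hypothetical example kernel; not part of the DeepSpeed sources.
__global__ void busy_kernel(float* out) { out[threadIdx.x] = threadIdx.x * 2.0f; }

int main()
{
    float* d_out;
    cudaMalloc((void**)&d_out, 256 * sizeof(float));

    GPUTimer gpu_timer;
    gpu_timer.Record();         // records the start event
    busy_kernel<<<1, 256>>>(d_out);
    float gpu_ms = 0.0f;
    gpu_timer.Elapsed(gpu_ms);  // records stop, synchronizes, returns elapsed ms

    CPUTimer cpu_timer;
    double acc = 0.0;
    for (int i = 0; i < 1000000; ++i) acc += i;
    float cpu_ms = cpu_timer.Elapsed();  // ms since construction or last Elapsed()/Reset()

    printf("acc=%.0f gpu: %.3f ms, cpu: %.3f ms\n", acc, gpu_ms, cpu_ms);
    cudaFree(d_out);
    return 0;
}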
deepspeed/ops/csrc/includes/context.h
0 → 100644
View file @
eadbbe09
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "cuda.h"
#include "curand.h"
#include "gemm_test.h"
#define WARP_SIZE 32
#define CUDA_CHECK(callstr) \
{ \
cudaError_t error_code = callstr; \
if (error_code != cudaSuccess) { \
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
assert(0); \
} \
}
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \
for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y)
#define DS_CUDA_NUM_THREADS 512
#define DS_MAXIMUM_NUM_BLOCKS 262144
inline int DS_GET_BLOCKS(const int N)
{
    return (std::max)(
        (std::min)((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS),
        // Use at least 1 block, since CUDA does not allow empty block
        1);
}

class Context {
public:
    Context() : _workspace(nullptr), _seed(42), _curr_offset(0)
    {
        curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(_gen, 123);
        if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
            auto message = std::string("Fail to create cublas handle.");
            std::cerr << message << std::endl;
            throw std::runtime_error(message);
        }
    }

    virtual ~Context()
    {
        cublasDestroy(_cublasHandle);
        cudaFree(_workspace);
    }

    static Context& Instance()
    {
        static Context _ctx;
        return _ctx;
    }

    void SetWorkSpace(void* workspace)
    {
        if (!workspace) { throw std::runtime_error("Workspace is null."); }
        _workspace = workspace;
    }

    void* GetWorkSpace() { return _workspace; }

    curandGenerator_t& GetRandGenerator() { return _gen; }

    cudaStream_t GetCurrentStream()
    {
        // get current pytorch stream.
        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        return stream;
    }

    cudaStream_t GetNewStream() { return at::cuda::getStreamFromPool(); }

    cublasHandle_t GetCublasHandle() { return _cublasHandle; }

    std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
    {
        uint64_t offset = _curr_offset;
        _curr_offset += offset_inc;
        return std::pair<uint64_t, uint64_t>(_seed, offset);
    }

    void SetSeed(uint64_t new_seed) { _seed = new_seed; }

    void TestGemmFP16(bool test_gemm, int batch_size, int seq_len, int head_num, int size_per_head)
    {
        // avoid rerun.
        if (_gemm_algos.size() > 0) return;

        if (test_gemm) {
            cublasHandle_t handle = GetCublasHandle();

            std::unique_ptr<GemmTest<__half>> test_qkv_fw(
                new GemmTest<__half>(batch_size * seq_len,      // M
                                     head_num * size_per_head,  // N
                                     head_num * size_per_head,  // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<GemmTest<__half>> test_inter(
                new GemmTest<__half>(batch_size * seq_len,          // M
                                     4 * head_num * size_per_head,  // N
                                     head_num * size_per_head,      // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<GemmTest<__half>> test_output(
                new GemmTest<__half>(batch_size * seq_len,          // M
                                     head_num * size_per_head,      // N
                                     4 * head_num * size_per_head,  // K
                                     CUBLAS_OP_T,
                                     CUBLAS_OP_N,
                                     handle));

            std::unique_ptr<StridedGemmTest<__half>> test_attn_scores(
                new StridedGemmTest<__half>(batch_size * head_num,  // batch
                                            seq_len,                // M
                                            seq_len,                // N
                                            size_per_head,          // K
                                            CUBLAS_OP_T,
                                            CUBLAS_OP_N,
                                            handle));

            std::unique_ptr<StridedGemmTest<__half>> test_attn_context(
                new StridedGemmTest<__half>(batch_size * head_num,  // batch
                                            size_per_head,          // M
                                            seq_len,                // N
                                            seq_len,                // K
                                            CUBLAS_OP_N,
                                            CUBLAS_OP_N,
                                            handle));

            _gemm_algos.push_back(test_qkv_fw->TestAlgo(100));
            _gemm_algos.push_back(test_inter->TestAlgo(100));
            _gemm_algos.push_back(test_output->TestAlgo(100));
            _gemm_algos.push_back(test_attn_scores->TestAlgo(100));
            _gemm_algos.push_back(test_attn_context->TestAlgo(100));
        } else {
            // Use default algo.
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
            _gemm_algos.push_back(std::array<int, 3>({99, 99, 99}));
        }
    }

    const std::vector<std::array<int, 3>>& GetGemmAlgos() const { return _gemm_algos; }

private:
    curandGenerator_t _gen;
    cublasHandle_t _cublasHandle;
    void* _workspace;
    uint64_t _seed;
    uint64_t _curr_offset;
    std::vector<std::array<int, 3>> _gemm_algos;
};
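A sketch of how the Context singleton above is typically consumed from a CUDA op (hypothetical usage, assumes the full DeepSpeed build environment so that ATen, cuBLAS, cuRAND and gemm_test.h are on the include path; not part of this commit):

// Hypothetical usage sketch of the Context singleton above.
#include "context.h"

void example_usage()
{
    Context& ctx = Context::Instance();

    // Kernels in this extension launch on the current PyTorch stream.
    cudaStream_t stream = ctx.GetCurrentStream();

    // The shared cuBLAS handle is bound to that stream before GEMM calls.
    cublasHandle_t handle = ctx.GetCublasHandle();
    cublasSetStream(handle, stream);

    // Dropout-style kernels consume (seed, offset) pairs from the shared RNG state.
    std::pair<uint64_t, uint64_t> seed_offset = ctx.IncrementOffset(16);
    (void)seed_offset;
}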