Unverified Commit 17ee854e authored by Deyu Fu, committed by GitHub

enable wider load/store for multi_tensor_apply kernels (#763)

* modify MTA axpby for wider load/store

* Make scale/axpby/l2/adam/lamb multi_tensor kernels use wider loads
parent 31aceeaa
@@ -14,6 +14,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

#include "type_shim.h"

typedef enum{
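The two helpers added above are the core of the change: is_aligned checks that a pointer sits on an ILP-element boundary, and load_store moves ILP consecutive elements through a single aligned-storage type, which the compiler can turn into one wide load or store (8 bytes for 4 halfs, 16 bytes for 4 floats). Below is a minimal standalone sketch, not part of the commit, showing how the two helpers combine; the kernel name toy_copy and its loop shape are made up for illustration.

// Illustrative only: a toy kernel that copies n elements with the same
// is_aligned/load_store helpers. The fast path assumes n % ILP == 0 and
// ILP-aligned pointers; otherwise a scalar loop is used.
#include <cstdint>
#include <type_traits>
#include <cuda_fp16.h>

#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  // ILP elements of T viewed as one aligned chunk: 4 halfs -> 8 bytes, 4 floats -> 16 bytes.
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename T>
__global__ void toy_copy(T* dst, T* src, int n)
{
  if(n % ILP == 0 && is_aligned(src) && is_aligned(dst))
  {
    // Each thread moves a contiguous group of ILP elements per iteration.
    for(int i = blockIdx.x*blockDim.x + threadIdx.x; i*ILP < n; i += gridDim.x*blockDim.x)
    {
      T buf[ILP];
      load_store(buf, src, 0, i);   // one wide load
      load_store(dst, buf, i, 0);   // one wide store
    }
  }
  else
  {
    for(int i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += gridDim.x*blockDim.x)
      dst[i] = src[i];
  }
}

The same aligned/fallback split appears in every kernel touched by this commit.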
@@ -99,24 +110,64 @@ struct AdamFunctor
        T incoming_v[ILP];
        T incoming_g[ILP];

        // to make things simple, we put aligned case in a different code path
        if(n % ILP == 0 &&
           chunk_size % ILP == 0 &&
           is_aligned(p) &&
           is_aligned(m) &&
           is_aligned(v) &&
           is_aligned(g) &&
           is_aligned(p_copy))
        {
          for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
          {
            // load
            GRAD_T tmp_g[ILP];
            load_store(incoming_p, p, 0, i_start);
            load_store(incoming_m, m, 0, i_start);
            load_store(incoming_v, v, 0, i_start);
            load_store(tmp_g, g, 0, i_start);
            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
              incoming_g[ii] = static_cast<T>(tmp_g[ii]);
              T scaled_grad = incoming_g[ii]/grad_scale;
              incoming_m[ii] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
              incoming_v[ii] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
              float denom;
              if (mode == ADAM_MODE_0)
                denom = sqrtf(incoming_v[ii] + eps);
              else // Mode 1
                denom = sqrtf(incoming_v[ii]) + eps;
              float update = (incoming_m[ii]/denom) + (decay*incoming_p[ii]);
              incoming_p[ii] = incoming_p[ii] - (step_size*update);
              if (DEPTH == 5)  tmp_g[ii] = static_cast<GRAD_T>(incoming_p[ii]);
            }
            load_store(p, incoming_p, i_start, 0);
            load_store(m, incoming_m, i_start, 0);
            load_store(v, incoming_v, i_start, 0);
            if (DEPTH == 5) load_store(p_copy, tmp_g, i_start, 0);
          }
        }
        else
        {
          for(int i_start = 0;
              i_start < n && i_start < chunk_size;
              i_start += blockDim.x*ILP) {
            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
              incoming_p[ii] = 0;
              incoming_m[ii] = 0;
              incoming_v[ii] = 0;
              incoming_g[ii] = 0;

              int i = i_start + threadIdx.x + ii*blockDim.x;
              if (i < n && i < chunk_size) {
                incoming_p[ii] = p[i];
                incoming_m[ii] = m[i];
                incoming_v[ii] = v[i];
                incoming_g[ii] = static_cast<T>(g[i]);
              }
            }

            // note for clarification to future michael:
@@ -124,24 +175,25 @@ struct AdamFunctor
            // the write loop, since writes just fire off once their LDGs arrive.
            // Put another way, the STGs are dependent on the LDGs, but not on each other.
            // There is still compute ILP benefit from unrolling the loop though.
            #pragma unroll
            for(int ii = 0; ii < ILP; ii++) {
              int j = i_start + threadIdx.x + ii*blockDim.x;
              if(j < n && j < chunk_size) {
                T scaled_grad = incoming_g[ii]/grad_scale;
                m[j] = b1*incoming_m[ii] + (1-b1)*scaled_grad;
                v[j] = b2*incoming_v[ii] + (1-b2)*scaled_grad*scaled_grad;
                float denom;
                if (mode == ADAM_MODE_0)
                  denom = sqrtf(v[j] + eps);
                else // Mode 1
                  denom = sqrtf(v[j]) + eps;
                float update = (m[j]/denom) + (decay*incoming_p[ii]);
                p[j] = incoming_p[ii] - (step_size*update);
                if (DEPTH == 5)  p_copy[j] = (GRAD_T) p[j];
              }
            }
          }
        }
      }
};
@@ -332,4 +384,3 @@ void fused_adam_cuda_mt(
  }
  THCudaCheck(cudaGetLastError());
}
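Besides the wider memory instructions, the aligned path changes how elements map to threads: in the fallback branch a thread's ILP elements are blockDim.x apart, while in the aligned branch each thread owns a contiguous group of ILP elements per iteration, which is what allows a single wide transaction. The host-side sketch below is illustrative only (not from the commit) and just prints the indices the first two threads of a block would touch under each scheme.

// Illustrative host-side sketch: compare the element indices thread t covers
// under the fallback scheme (i = i_start + t + ii*blockDim.x) and under the
// aligned scheme (a contiguous group [i_start*ILP, i_start*ILP + ILP)).
#include <cstdio>

int main()
{
  const int block_dim = 512;   // BLOCK_SIZE in the kernels
  const int ILP = 4;

  for(int t = 0; t < 2; t++)   // just the first two threads of block 0
  {
    printf("thread %d, fallback path, first iteration: ", t);
    int i_start = 0;
    for(int ii = 0; ii < ILP; ii++)
      printf("%d ", i_start + t + ii*block_dim);   // elements strided by blockDim.x

    printf("\nthread %d, aligned path, first iteration:  ", t);
    int v_start = t;                               // i_start begins at threadIdx.x
    for(int e = 0; e < ILP; e++)
      printf("%d ", v_start*ILP + e);              // contiguous group of ILP elements
    printf("\n");
  }
  return 0;
}

In both schemes a warp's accesses stay coalesced; the aligned scheme simply lets each access be ILP elements wide.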
@@ -13,6 +13,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename x_t, typename y_t, typename out_t>
struct AxpbyFunctor
{
@@ -43,46 +54,74 @@ struct AxpbyFunctor
      n -= chunk_idx*chunk_size;

      bool finite = true;
      x_t r_x[ILP];
      y_t r_y[ILP];
      out_t r_out[ILP];

      // to make things simple, we put aligned case in a different code path
      if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x) && is_aligned(y) && is_aligned(out))
      {
        for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
        {
          // load
          load_store(r_x, x, 0, i_start);
          load_store(r_y, y, 0, i_start);
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
            if(arg_to_check == -1)
              finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
            if(arg_to_check == 0)
              finite = finite && isfinite(r_x[ii]);
            if(arg_to_check == 1)
              finite = finite && isfinite(r_y[ii]);
          }
          // store
          load_store(out, r_out, i_start, 0);
        }
      }
      else
      {
        // Non-divergent exit condition for __syncthreads, not necessary here
        for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
        {
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_x[ii] = 0;
            r_y[ii] = 0;
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              r_x[ii] = x[i];
              r_y[ii] = y[i];
            }
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_out[ii] = a*static_cast<float>(r_x[ii]) + b*static_cast<float>(r_y[ii]);
            if(arg_to_check == -1)
              finite = finite && (isfinite(r_x[ii]) && isfinite(r_y[ii]));
            if(arg_to_check == 0)
              finite = finite && isfinite(r_x[ii]);
            if(arg_to_check == 1)
              finite = finite && isfinite(r_y[ii]);
          }
          // see note in multi_tensor_scale_kernel.cu
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
              out[i] = r_out[ii];
          }
        }
      }
      if(!finite)
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
    }
};
...
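One behavioral detail in the axpby (and scale) kernels: the old code recomputed finite per element and wrote *noop_gmem from inside the loop, whereas the new code accumulates one per-thread finite flag across everything the thread processed and fires at most one racy write at the end. A stripped-down sketch of that pattern follows; it is illustrative only, and the kernel name check_finite and flag name noop_flag are invented.

// Illustrative only: accumulate a per-thread "all values finite" flag and
// write the global no-op flag once per thread instead of once per element.
__global__ void check_finite(const float* x, int n, int* noop_flag)
{
  bool finite = true;
  for(int i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += gridDim.x*blockDim.x)
    finite = finite && isfinite(x[i]);

  // Threads that saw a non-finite value race on this store; they all write the
  // same value, so the race is benign (same reasoning as in the kernels above).
  if(!finite)
    *noop_flag = 1;
}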
@@ -13,6 +13,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename x_t>
struct L2NormFunctor
{
@@ -41,22 +52,44 @@ struct L2NormFunctor
      __shared__ float s_vals[512];

      float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
      x_t r_x[ILP];
      for(int i = 0; i < ILP; i++)
      {
        vals[i] = 0.f;
        r_x[i] = 0;
      }

      // to make things simple, we put aligned case in a different code path
      if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
      {
        for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
        {
          // load
          load_store(r_x, x, 0, i_start);
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            float next = static_cast<float>(r_x[ii]);
            vals[ii] += next*next;
          }
        }
      }
      else
      {
        for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
        {
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              float next = static_cast<float>(x[i]);
              vals[ii] += next*next;
            }
          }
        }
      }

      float val = 0.f;
      for(int i = 0; i < ILP; i++)
@@ -104,22 +137,44 @@ struct MaxNormFunctor
      __shared__ float s_vals[512];

      float vals[ILP]; // = {0}; // this probably works too but I want to be sure...
      x_t r_x[ILP];
      for(int i = 0; i < ILP; i++)
      {
        vals[i] = 0.f;
        r_x[i] = 0;
      }

      // to make things simple, we put aligned case in a different code path
      if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(x))
      {
        for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
        {
          // load
          load_store(r_x, x, 0, i_start);
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            float next = static_cast<float>(r_x[ii]);
            vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
          }
        }
      }
      else
      {
        for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
        {
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              float next = static_cast<float>(x[i]);
              vals[ii] = fmaxf(fabsf(vals[ii]), fabsf(next));
            }
          }
        }
      }

      float val = 0.f;
      for(int i = 0; i < ILP; i++)
...
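Note that is_aligned is evaluated on the pointer after it has been advanced by chunk_idx*chunk_size, so the check already accounts for the chunk offset; because that offset is a multiple of chunk_size, requiring chunk_size % ILP == 0 in the same guard means the offset by itself never breaks ILP alignment. The host-side sketch below is illustrative only (is_aligned_host and the 2048-element chunk length are placeholders, not values from the commit) and shows which offsets keep the fast path available for a half buffer.

// Illustrative host-side sketch: what the is_aligned guard accepts for a
// half-precision buffer. ILP*sizeof(half) = 8 bytes, so offsets that are
// multiples of ILP elements keep the fast path available.
#include <cstdint>
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#define ILP 4

template<typename T>
bool is_aligned_host(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

int main()
{
  half* base = nullptr;
  cudaMalloc(&base, 1 << 20);           // cudaMalloc returns generously aligned memory

  const int chunk_size = 2048;          // example chunk length in elements (a multiple of ILP)
  printf("base        aligned: %d\n", is_aligned_host(base));
  printf("base+chunk  aligned: %d\n", is_aligned_host(base + chunk_size));  // still aligned
  printf("base+3      aligned: %d\n", is_aligned_host(base + 3));           // falls back

  cudaFree(base);
  return 0;
}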
@@ -13,6 +13,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

typedef enum{
  MOMENT_MODE_0 =0, // L2 regularization mode
  MOMENT_MODE_1 =1  // Decoupled weight decay mode
@@ -68,71 +79,149 @@ struct LAMBStage1Functor
      n -= chunk_idx*chunk_size;

      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];

      // to make things simple, we put aligned case in a different code path
      if(n % ILP == 0 &&
         chunk_size % ILP == 0 &&
         is_aligned(g) &&
         is_aligned(p) &&
         is_aligned(m) &&
         is_aligned(v))
      {
        T l_g[ILP];
        T l_p[ILP];
        T l_m[ILP];
        T l_v[ILP];
        for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
        {
          // load
          load_store(l_g, g, 0, i_start);
          if (decay != 0)
            load_store(l_p, p, 0, i_start);
          load_store(l_m, m, 0, i_start);
          load_store(l_v, v, 0, i_start);
          // unpack
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_g[ii] = l_g[ii];
            // special ?optimization? for lamb stage 1
            if (decay == 0) {
              r_p[ii] = MATH_T(0);
            }
            else {
              r_p[ii] = l_p[ii];
            }
            r_m[ii] = l_m[ii];
            r_v[ii] = l_v[ii];
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            if (mode == MOMENT_MODE_0) {
              MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
              // L2 on scaled grad
              scaled_grad = scaled_grad + decay*r_p[ii];
              r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
              r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
              MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
              MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
              MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
              r_p[ii] = next_m_unbiased / denom;
            }
            else {
              MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
              r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
              r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
              MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
              MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
              MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
              r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
            }
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            l_p[ii] = r_p[ii];
            l_m[ii] = r_m[ii];
            l_v[ii] = r_v[ii];
          }
          // store
          load_store(g, l_p, i_start, 0);
          load_store(m, l_m, i_start, 0);
          load_store(v, l_v, i_start, 0);
        }
      }
      else
      {
        // see note in multi_tensor_scale_kernel.cu
        for(int i_start = 0;
            i_start < n && i_start < chunk_size;
            i_start += blockDim.x*ILP)
        {
          MATH_T r_g[ILP];
          MATH_T r_p[ILP];
          MATH_T r_m[ILP];
          MATH_T r_v[ILP];
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              r_g[ii] = g[i];
              // special ?optimization? for lamb stage 1
              if (decay == 0) {
                r_p[ii] = MATH_T(0);
              }
              else {
                r_p[ii] = p[i];
              }
              r_m[ii] = m[i];
              r_v[ii] = v[i];
            } else {
              r_g[ii] = MATH_T(0);
              r_p[ii] = MATH_T(0);
              r_m[ii] = MATH_T(0);
              r_v[ii] = MATH_T(0);
            }
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            if (mode == MOMENT_MODE_0) {
              MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
              // L2 on scaled grad
              scaled_grad = scaled_grad + decay*r_p[ii];
              r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
              r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
              MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
              MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
              MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
              r_p[ii] = next_m_unbiased / denom;
            }
            else {
              MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
              r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
              r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
              MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
              MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
              MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
              r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
            }
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              g[i] = r_p[ii];
              m[i] = r_m[ii];
              v[i] = r_v[ii];
            }
          }
        }
      }
@@ -173,34 +262,58 @@ struct LAMBStage2Functor
      n -= chunk_idx*chunk_size;

      // to make things simple, we put aligned case in a different code path
      if(n % ILP == 0 &&
         chunk_size % ILP == 0 &&
         is_aligned(p) &&
         is_aligned(update))
      {
        T r_p[ILP];
        T r_update[ILP];
        for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
        {
          // load
          load_store(r_p, p, 0, i_start);
          load_store(r_update, update, 0, i_start);
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
          }
          load_store(p, r_p, i_start, 0);
        }
      }
      else
      {
        for(int i_start = 0;
            i_start < n && i_start < chunk_size;
            i_start += blockDim.x*ILP)
        {
          MATH_T r_p[ILP];
          MATH_T r_update[ILP];
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              r_p[ii] = p[i];
              r_update[ii] = update[i];
            }
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
            {
              p[i] = r_p[ii];
            }
          }
        }
      }
...
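The LAMB stage-1 aligned path stages data twice: l_g/l_p/l_m/l_v hold values in the storage type T so load_store can move them as one wide transaction, while r_g/r_p/r_m/r_v hold the same values in MATH_T for the arithmetic, and results are packed back into the l_ buffers before the wide store. The stripped-down sketch below illustrates that pattern only; scale_in_math_type, STORE_T, and alpha are invented names, and it assumes the caller already verified the size and alignment preconditions.

// Illustrative only: load in the storage type, compute in a wider math type,
// pack back, then store wide. STORE_T could be half while MATH_T is float.
#include <cstdint>
#include <type_traits>
#include <cuda_fp16.h>

#define ILP 4

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename STORE_T, typename MATH_T>
__global__ void scale_in_math_type(STORE_T* x, MATH_T alpha, int n)
{
  // Assumes n % ILP == 0 and x is ILP-aligned (checked by the caller).
  for(int i = blockIdx.x*blockDim.x + threadIdx.x; i*ILP < n; i += gridDim.x*blockDim.x)
  {
    STORE_T l_x[ILP];            // storage-type buffer: what load_store moves
    MATH_T  r_x[ILP];            // math-type buffer: what the arithmetic uses

    load_store(l_x, x, 0, i);    // one wide load
    #pragma unroll
    for(int ii = 0; ii < ILP; ii++)
      r_x[ii] = static_cast<MATH_T>(l_x[ii]) * alpha;   // compute in MATH_T
    #pragma unroll
    for(int ii = 0; ii < ILP; ii++)
      l_x[ii] = static_cast<STORE_T>(r_x[ii]);          // pack back to storage type
    load_store(x, l_x, i, 0);    // one wide store
  }
}

Instantiated as scale_in_math_type<half, float>, the loads and stores move 8 bytes at a time while the multiply still runs in fp32, which mirrors how the stage-1 kernel keeps T for memory traffic and MATH_T for math.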
@@ -15,6 +15,17 @@
#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename in_t, typename out_t>
struct ScaleFunctor
{
@@ -34,44 +45,68 @@ struct ScaleFunctor
      in_t* in = (in_t*)tl.addresses[0][tensor_loc];
      in += chunk_idx*chunk_size;

      out_t* out = (out_t*)tl.addresses[1][tensor_loc];
      out += chunk_idx*chunk_size;

      n -= chunk_idx*chunk_size;

      bool finite = true;
      in_t r_in[ILP];
      out_t r_out[ILP];

      // to make things simple, we put aligned case in a different code path
      if(n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) && is_aligned(out))
      {
        for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
        {
          // load
          load_store(r_in, in, 0, i_start);
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_out[ii] = static_cast<float>(r_in[ii]) * scale;
            finite = finite && isfinite(r_in[ii]);
          }
          // store
          load_store(out, r_out, i_start, 0);
        }
      }
      else
      {
        // Non-divergent exit condition for __syncthreads, not necessary here
        for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
        {
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_in[ii] = 0;
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
              r_in[ii] = in[i];
          }
          // note for clarification to future michael:
          // From a pure memory dependency perspective, there's likely no point unrolling
          // the write loop, since writes just fire off once their LDGs arrive.
          // Put another way, the STGs are dependent on the LDGs, but not on each other.
          // There is still compute ILP benefit from unrolling the loop though.
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            r_out[ii] = static_cast<float>(r_in[ii]) * scale;
            finite = finite && isfinite(r_in[ii]);
          }
          #pragma unroll
          for(int ii = 0; ii < ILP; ii++)
          {
            int i = i_start + threadIdx.x + ii*blockDim.x;
            if(i < n && i < chunk_size)
              out[i] = r_out[ii];
          }
        }
      }
      if(!finite)
        *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
    }
};
...
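Since the fallback branch in each kernel is the old implementation, a quick sanity check for this kind of change is to force both paths over the same aligned data and compare the outputs. The sketch below is illustrative only: it uses a simplified standalone scale_kernel with a use_fast_path switch rather than the ScaleFunctor/multi_tensor_apply machinery, and all sizes and launch parameters are arbitrary.

// Illustrative only: run a simplified scale through the aligned fast path and
// the scalar fallback on identical input and verify the outputs match.
#include <cstdio>
#include <cstdint>
#include <vector>
#include <type_traits>
#include <cuda_runtime.h>

#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

// use_fast_path lets the test force either branch on the same (aligned) data.
__global__ void scale_kernel(float* in, float* out, int n, float scale, bool use_fast_path)
{
  if(use_fast_path && n % ILP == 0 && is_aligned(in) && is_aligned(out))
  {
    for(int i = blockIdx.x*blockDim.x + threadIdx.x; i*ILP < n; i += gridDim.x*blockDim.x)
    {
      float r_in[ILP], r_out[ILP];
      load_store(r_in, in, 0, i);
      #pragma unroll
      for(int ii = 0; ii < ILP; ii++)
        r_out[ii] = r_in[ii] * scale;
      load_store(out, r_out, i, 0);
    }
  }
  else
  {
    for(int i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += gridDim.x*blockDim.x)
      out[i] = in[i] * scale;
  }
}

int main()
{
  const int n = 1 << 20;
  std::vector<float> h_in(n), h_fast(n), h_slow(n);
  for(int i = 0; i < n; i++) h_in[i] = 0.001f * i;

  float *d_in, *d_out;
  cudaMalloc(&d_in, n*sizeof(float));
  cudaMalloc(&d_out, n*sizeof(float));
  cudaMemcpy(d_in, h_in.data(), n*sizeof(float), cudaMemcpyHostToDevice);

  scale_kernel<<<320, 512>>>(d_in, d_out, n, 0.5f, true);
  cudaMemcpy(h_fast.data(), d_out, n*sizeof(float), cudaMemcpyDeviceToHost);
  scale_kernel<<<320, 512>>>(d_in, d_out, n, 0.5f, false);
  cudaMemcpy(h_slow.data(), d_out, n*sizeof(float), cudaMemcpyDeviceToHost);

  int mismatches = 0;
  for(int i = 0; i < n; i++)
    if(h_fast[i] != h_slow[i]) mismatches++;
  printf("mismatches: %d\n", mismatches);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}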