Commit fbb8cd93 authored by Jeff Daily

Revert "pass all TensorListMetadata as pointer to pinned host memory (#13)"

This reverts commit bdd481d1.
parent 3f49dbf0
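
For context: every functor in this diff receives a TensorListMetadata describing which tensor and which chunk the current CUDA block should process. The sketch below is illustrative only (the capacities MAX_TENSORS and MAX_BLOCKS and the int element types are assumptions, not the repository's actual definitions); it shows the layout the functors index into and the two calling conventions this revert switches between.

// Hedged sketch of the metadata struct implied by the diff; capacities are assumed.
#define MAX_TENSORS 36
#define MAX_BLOCKS 320

template<int n>
struct TensorListMetadata
{
  void* addresses[n][MAX_TENSORS];   // one device pointer per tensor list per tensor
  int sizes[MAX_TENSORS];            // element count of each tensor
  int block_to_tensor[MAX_BLOCKS];   // which tensor blockIdx.x works on
  int block_to_chunk[MAX_BLOCKS];    // which chunk of that tensor it works on
  int start_tensor_this_launch;      // offset into per-tensor output buffers
};

// Convention being reverted (the '-' lines below): metadata staged in pinned host
// memory and passed to the functor as a pointer.
//   void operator()(int chunk_size, volatile int* noop_gmem,
//                   TensorListMetadata<DEPTH>* tl, ...)
// Convention being restored (the '+' lines below): metadata passed by value as a
// kernel argument and handed to the functor by reference.
//   void operator()(int chunk_size, volatile int* noop_gmem,
//                   TensorListMetadata<DEPTH>& tl, ...)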
@@ -76,7 +76,7 @@ struct AdamFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<DEPTH>* tl,
+    TensorListMetadata<DEPTH>& tl,
     const float b1,
     const float b2,
     const float eps,
@@ -85,21 +85,21 @@ struct AdamFunctor
     adamMode_t mode,
     const float decay)
   {
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    T* p = (T *)tl->addresses[0][tensor_loc];
+    T* p = (T *)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T *)tl->addresses[1][tensor_loc];
+    T* m = (T *)tl.addresses[1][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T *)tl->addresses[2][tensor_loc];
+    T* v = (T *)tl.addresses[2][tensor_loc];
     v += chunk_idx*chunk_size;
-    GRAD_T* g = (GRAD_T *)tl->addresses[3][tensor_loc];
+    GRAD_T* g = (GRAD_T *)tl.addresses[3][tensor_loc];
     g += chunk_idx*chunk_size;
     GRAD_T* p_copy = NULL;
     if (DEPTH == 5) {
-      p_copy = (GRAD_T *)tl->addresses[4][tensor_loc];
+      p_copy = (GRAD_T *)tl.addresses[4][tensor_loc];
       p_copy += chunk_idx*chunk_size;
     }
@@ -736,17 +736,17 @@ struct MaybeCastFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* overflow_flag,
-    TensorListMetadata<DEPTH>* tl)
+    TensorListMetadata<DEPTH>& tl)
   {
     if (overflow_flag && *overflow_flag != 0) return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    FROM_T* p_in = (FROM_T *)tl->addresses[0][tensor_loc];
+    FROM_T* p_in = (FROM_T *)tl.addresses[0][tensor_loc];
     p_in += chunk_idx*chunk_size;
-    TO_T* p_out = (TO_T *)tl->addresses[1][tensor_loc];
+    TO_T* p_out = (TO_T *)tl.addresses[1][tensor_loc];
     p_out += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -32,7 +32,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>* tl,
+    TensorListMetadata<4>& tl,
     const float beta1,
     const float beta2,
     const float beta3,
@@ -48,22 +48,22 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
     float clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : 1.0f;
-    T* g = (T*)tl->addresses[0][tensor_loc];
+    T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl->addresses[1][tensor_loc];
+    T* p = (T*)tl.addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl->addresses[2][tensor_loc];
+    T* m = (T*)tl.addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl->addresses[3][tensor_loc];
+    T* v = (T*)tl.addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -147,7 +147,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>* tl,
+    TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate,
@@ -157,10 +157,10 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
     MATH_T ratio = learning_rate;
     // apply adaptive learning rate to parameters with non-zero weight decay
@@ -171,10 +171,10 @@ struct LAMBStage2Functor
       ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
     }
-    T* update = (T*)tl->addresses[0][tensor_loc];
+    T* update = (T*)tl.addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;
-    T* p = (T*)tl->addresses[1][tensor_loc];
+    T* p = (T*)tl.addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -23,20 +23,20 @@ using MATH_T = float;
 template <typename T> struct AdagradFunctor {
   __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> *tl,
+  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
              const float epsilon, const float lr, adagradMode_t mode,
              const float weight_decay) {
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    T *g = (T *)tl->addresses[0][tensor_loc];
+    T *g = (T *)tl.addresses[0][tensor_loc];
     g += chunk_idx * chunk_size;
-    T *p = (T *)tl->addresses[1][tensor_loc];
+    T *p = (T *)tl.addresses[1][tensor_loc];
     p += chunk_idx * chunk_size;
-    T *h = (T *)tl->addresses[2][tensor_loc];
+    T *h = (T *)tl.addresses[2][tensor_loc];
     h += chunk_idx * chunk_size;
     n -= chunk_idx * chunk_size;
......
@@ -26,7 +26,7 @@ struct AdamFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>* tl,
+    TensorListMetadata<4>& tl,
     const float beta1,
     const float beta2,
     const float beta1_correction,
@@ -40,24 +40,24 @@ struct AdamFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
     // potentially use to pass in list of scalar
-    // int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    T* g = (T*)tl->addresses[0][tensor_loc];
+    T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl->addresses[1][tensor_loc];
+    T* p = (T*)tl.addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl->addresses[2][tensor_loc];
+    T* m = (T*)tl.addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl->addresses[3][tensor_loc];
+    T* v = (T*)tl.addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -34,7 +34,7 @@ __launch_bounds__(1024)
 __global__ void multi_tensor_apply_kernel(
     int chunk_size,
     volatile int* noop_flag,
-    T* tl,
+    T tl,
     U callable,
     ArgTypes... args)
 {
@@ -111,15 +111,11 @@ void multi_tensor_apply(
       bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
       if(tensors_full || blocks_full || last_chunk)
      {
-        auto storage = at::empty(sizeof(tl), c10::TensorOptions(at::kStrided).dtype(at::kByte).device(at::kCPU).pinned_memory(true));
-        auto tl_as_host_pinned_ptr = static_cast<decltype(tl)*>(storage.data_ptr());
-        memcpy(tl_as_host_pinned_ptr, &tl, sizeof(tl));
-        AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(tl_as_host_pinned_ptr, stream));
         // using accscalar_t = acc_type<scalar_t, true>;
         multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
           chunk_size,
           noop_flag.DATA_PTR<int>(),
-          tl_as_host_pinned_ptr,
+          tl,
           callable,
           args...);
......
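
The two hunks above are where the calling convention actually changes: the reverted code staged the metadata in pinned host memory, recorded an event with the caching host allocator, and handed the kernel a host pointer, while the restored code simply passes the struct by value as a kernel argument. The kernel body is elided in the diff, so the forwarding call in the sketch below is an assumption, and ScaleLikeFunctor is an illustrative stand-in rather than any functor from this repository.

#include <cuda_runtime.h>

// Assumed kernel body: forward all arguments to the functor (the real body is elided above).
template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
    int chunk_size,
    volatile int* noop_flag,
    T tl,            // metadata copied by value into kernel argument space
    U callable,
    ArgTypes... args)
{
  callable(chunk_size, noop_flag, tl, args...);
}

// Illustrative functor following the same pattern as those in this diff:
// locate this block's tensor and chunk, then stride over the chunk.
struct ScaleLikeFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,   // struct sketched near the top of this page
    float scale)
  {
    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float* in = (float*)tl.addresses[0][tensor_loc] + chunk_idx*chunk_size;
    float* out = (float*)tl.addresses[1][tensor_loc] + chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;            // elements remaining in this tensor
    if(n > chunk_size) n = chunk_size;    // clamp to a single chunk

    for(int i = threadIdx.x; i < n; i += blockDim.x)
      out[i] = in[i]*scale;
  }
};

Passing tl by value keeps the per-launch metadata in the kernel's argument space; the removed pinned-host path instead had every block dereference the metadata out of host memory.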
@@ -30,7 +30,7 @@ struct AxpbyFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<3>* tl,
+    TensorListMetadata<3>& tl,
     float a,
     float b,
     int arg_to_check)
@@ -39,17 +39,17 @@ struct AxpbyFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;
-    y_t* y = (y_t*)tl->addresses[1][tensor_loc];
+    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
     y += chunk_idx*chunk_size;
-    out_t* out = (out_t*)tl->addresses[2][tensor_loc];
+    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
     out += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -31,7 +31,7 @@ struct L2NormFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<1>* tl,
+    TensorListMetadata<1>& tl,
     float* output,
     float* output_per_tensor,
     bool per_tensor,
@@ -41,11 +41,11 @@ struct L2NormFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -104,7 +104,7 @@ struct L2NormFunctor
         *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] += final;
       if(per_tensor)
-        output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };
@@ -116,7 +116,7 @@ struct MaxNormFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<1>* tl,
+    TensorListMetadata<1>& tl,
     float* output,
     float* output_per_tensor,
     bool per_tensor,
@@ -126,11 +126,11 @@ struct MaxNormFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -189,7 +189,7 @@ struct MaxNormFunctor
         *noop_gmem = 1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
       if(per_tensor)
-        output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };
......
@@ -43,7 +43,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>* tl,
+    TensorListMetadata<4>& tl,
     const float beta1,
     const float beta2,
     const float beta3,
@@ -59,22 +59,22 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
     float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
-    T* g = (T*)tl->addresses[0][tensor_loc];
+    T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl->addresses[1][tensor_loc];
+    T* p = (T*)tl.addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl->addresses[2][tensor_loc];
+    T* m = (T*)tl.addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl->addresses[3][tensor_loc];
+    T* v = (T*)tl.addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
@@ -236,7 +236,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>* tl,
+    TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate,
@@ -247,10 +247,10 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
     MATH_T ratio = learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
@@ -262,10 +262,10 @@ struct LAMBStage2Functor
       ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
     }
-    T* update = (T*)tl->addresses[0][tensor_loc];
+    T* update = (T*)tl.addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;
-    T* p = (T*)tl->addresses[1][tensor_loc];
+    T* p = (T*)tl.addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -20,7 +20,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<5>* tl,
+    TensorListMetadata<5>& tl,
     const float* per_tensor_decay,
     const float beta1,
     const float beta2,
@@ -33,26 +33,26 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
     float decay = per_tensor_decay[tensor_num];
-    GRAD_T* g = (GRAD_T*)tl->addresses[0][tensor_loc];
+    GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl->addresses[1][tensor_loc];
+    T* p = (T*)tl.addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl->addresses[2][tensor_loc];
+    T* m = (T*)tl.addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
-    T* v = (T*)tl->addresses[3][tensor_loc];
+    T* v = (T*)tl.addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;
-    UPD_T* update = (UPD_T*)tl->addresses[4][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -23,7 +23,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>* tl,
+    TensorListMetadata<2>& tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate,
@@ -34,10 +34,10 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
     MATH_T ratio = learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
@@ -49,10 +49,10 @@ struct LAMBStage2Functor
      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
     }
-    T* p = (T*)tl->addresses[0][tensor_loc];
+    T* p = (T*)tl.addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;
-    UPD_T* update = (UPD_T*)tl->addresses[1][tensor_loc];
+    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
     update += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -35,7 +35,7 @@ struct NovoGradFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<3>* tl,
+    TensorListMetadata<3>& tl,
     const float beta1,
     const float beta2,
     const float beta3,
@@ -51,20 +51,20 @@ struct NovoGradFunctor
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
     float grad_norm = per_tensor_grad_norm[tensor_num];
-    T* g = (T*)tl->addresses[0][tensor_loc];
+    T* g = (T*)tl.addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;
-    T* p = (T*)tl->addresses[1][tensor_loc];
+    T* p = (T*)tl.addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;
-    T* m = (T*)tl->addresses[2][tensor_loc];
+    T* m = (T*)tl.addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -32,21 +32,21 @@ struct ScaleFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>* tl,
+    TensorListMetadata<2>& tl,
     float scale)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     // return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    in_t* in = (in_t*)tl->addresses[0][tensor_loc];
+    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
     in += chunk_idx*chunk_size;
-    out_t* out = (out_t*)tl->addresses[1][tensor_loc];
+    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
     out += chunk_idx*chunk_size;
     n -= chunk_idx*chunk_size;
......
@@ -32,7 +32,7 @@ struct SGDFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<N>* tl,
+    TensorListMetadata<N>& tl,
     float wd,
     float momentum,
     float dampening,
@@ -45,23 +45,23 @@ struct SGDFunctor
     // Early exit if we don't need to do anything
     if (*noop_gmem) return;
-    int tensor_loc = tl->block_to_tensor[blockIdx.x];
-    int chunk_idx = tl->block_to_chunk[blockIdx.x];
-    int n = tl->sizes[tensor_loc];
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
-    T_grad* grad_in = (T_grad*)tl->addresses[0][tensor_loc];
+    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
     grad_in += chunk_idx*chunk_size;
-    T_weight* weight_in = (T_weight*)tl->addresses[1][tensor_loc];
+    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
     weight_in += chunk_idx*chunk_size;
-    T_weight* mom_in = (T_weight*)tl->addresses[2][tensor_loc];
+    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
     mom_in += chunk_idx*chunk_size;
     at::Half *model_weights_out = nullptr;
     if(N == 4)
     {
-      model_weights_out = (at::Half*)tl->addresses[3][tensor_loc];
+      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
       model_weights_out += chunk_idx*chunk_size;
     }
......