Unverified Commit cf50dc7c authored by Deyu Fu's avatar Deyu Fu Committed by GitHub
Browse files

Changes to make xentropysoftmax load/store vectorized when possible: (#725)

* Changes to make xentropysoftmax load/store vectorized when possible:
Increase default ILP so that each thread handle 16 Bytes data in one step
Make thread load/store longest vector possible
Make unroll case handle adjacent data instead of strided, so same order compare to vector case

* Add shift for not aligned case. Remove less than 16 bytes aligned access
parent 17ee854e
/**
* From PyTorch:
*
*
* Copyright (c) 2016- Facebook, Inc (Adam Paszke)
* Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
* Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
......@@ -10,54 +10,54 @@
* Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
* Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
* Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
*
*
* From Caffe2:
*
*
* Copyright (c) 2016-present, Facebook Inc. All rights reserved.
*
*
* All contributions by Facebook:
* Copyright (c) 2016 Facebook Inc.
*
*
* All contributions by Google:
* Copyright (c) 2015 Google Inc.
* All rights reserved.
*
*
* All contributions by Yangqing Jia:
* Copyright (c) 2015 Yangqing Jia
* All rights reserved.
*
*
* All contributions from Caffe:
* Copyright(c) 2013, 2014, 2015, the respective contributors
* All rights reserved.
*
*
* All other contributions:
* Copyright(c) 2015, 2016 the respective contributors
* All rights reserved.
*
*
* Caffe2 uses a copyright model similar to Caffe: each contributor holds
* copyright over their contributions to Caffe2. The project versioning records
* all such contribution and copyright details. If a contributor wants to further
* mark their specific copyright on a particular contribution, they should
* indicate their copyright solely in the commit message of the change when it is
* committed.
*
*
* All rights reserved.
*
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
*
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
*
*
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
*
* 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
* and IDIAP Research Institute nor the names of its contributors may be
* used to endorse or promote products derived from this software without
* specific prior written permission.
*
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
......@@ -70,7 +70,6 @@
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.
*/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
......@@ -84,6 +83,8 @@
#include "type_shim.h"
#include "compat.h"
#define ALIGN_BYTES 16
using Tensor = at::Tensor;
using TensorList = at::TensorList;
using ScalarType = at::ScalarType;
......@@ -123,7 +124,7 @@ const int max_threads = 1024;
inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) {
uint64_t block_size = 1;
uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads));
while (block_size < max_block_size) block_size *= 2;
while (block_size < (max_block_size/2)) block_size *= 2;
// Launch at least a single warp - the kernel assumes that.
block_size = std::max(block_size, static_cast<uint64_t>(32));
return dim3(block_size);
......@@ -287,29 +288,40 @@ blockReduce(AccumT* smem,
template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
__device__ __forceinline__ AccumT
ilpReduce(T* data,
ilpReduce(int shift,
T* data,
int size,
const Reduction<T, AccumT>& r,
AccumT defaultVal)
{
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
AccumT threadVal = defaultVal;
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
data -= shift;
size += shift;
if(threadIdx.x >= shift){
threadVal = r(threadVal, data[offset]);
}
size -= blockDim.x;
data += blockDim.x;
}
int last = size % (ILP * blockDim.x);
// Body (unroll by ILP times)
for (; offset < size - last; offset += blockDim.x * ILP) {
T tmp[ILP];
T v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
#pragma unroll
for (int j = 0; j < ILP; ++j)
tmp[j] = data[offset + j * blockDim.x];
for (; offset * ILP < (size - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(data)[offset];
#pragma unroll
for (int j = 0; j < ILP; ++j)
threadVal = r(threadVal, tmp[j]);
for (int j = 0; j < ILP; ++j) {
threadVal = r(threadVal, v[j]);
}
}
offset = size - last + threadIdx.x;
// Epilogue
for (; offset < size; offset += blockDim.x)
threadVal = r(threadVal, data[offset]);
......@@ -319,7 +331,8 @@ ilpReduce(T* data,
template <template<typename, typename> class Reduction1, template<typename, typename> class Reduction2, int ILP, typename T, typename AccumT>
__device__ __forceinline__ void
ilpReduce(T* data,
ilpReduce(int shift,
T* data,
int size,
AccumT* reducVal1,
const Reduction1<T, AccumT>& r1,
......@@ -328,27 +341,38 @@ ilpReduce(T* data,
const Reduction2<T, AccumT>& r2,
AccumT defaultVal2)
{
typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LoadT;
AccumT threadVal1 = defaultVal1;
AccumT threadVal2 = defaultVal2;
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
data -= shift;
size += shift;
if(threadIdx.x >= shift){
threadVal1 = r1(threadVal1, data[offset]);
threadVal2 = r2(threadVal2, data[offset]);
}
size -= blockDim.x;
data += blockDim.x;
}
int last = size % (ILP * blockDim.x);
// Body (unroll by ILP times)
for (; offset < size - last; offset += blockDim.x * ILP) {
T tmp[ILP];
T v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
#pragma unroll
for (int j = 0; j < ILP; ++j)
tmp[j] = data[offset + j * blockDim.x];
for (; offset * ILP < (size - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(data)[offset];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
threadVal1 = r1(threadVal1, tmp[j]);
threadVal2 = r2(threadVal2, tmp[j]);
threadVal1 = r1(threadVal1, v[j]);
threadVal2 = r2(threadVal2, v[j]);
}
}
offset = size - last + threadIdx.x;
// Epilogue
for (; offset < size; offset += blockDim.x) {
threadVal1 = r1(threadVal1, data[offset]);
......@@ -375,17 +399,19 @@ cunn_SoftMaxXEntropyForward(
// each block handles a sample in the mini-batch
input += blockIdx.x * classes;
//output += blockIdx.x * classes;
const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t);
int64_t label = labels[blockIdx.x];
// find the max and sum
accscalar_t threadMax, threadSum, max_k, sum_k;
ilpReduce<MaxFloat, AddFloat, ILP, scalar_t, accscalar_t>(
input, classes,
&threadMax, MaxFloat<scalar_t, accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&threadSum, AddFloat<scalar_t, accscalar_t>(),
static_cast<accscalar_t>(0));
shift, input, classes,
&threadMax, MaxFloat<scalar_t, accscalar_t>(),
-at::numeric_limits<accscalar_t>::max(),
&threadSum, AddFloat<scalar_t, accscalar_t>(),
static_cast<accscalar_t>(0));
blockReduce<Max, Add, accscalar_t>(
sdata,
&max_k, threadMax, Max<accscalar_t>(),
......@@ -393,9 +419,7 @@ cunn_SoftMaxXEntropyForward(
&sum_k, threadSum, Add<accscalar_t>(),
static_cast<accscalar_t>(0));
// reduce all values
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0));
accscalar_t sumAll = blockReduce<Add, accscalar_t>(
sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0));
......@@ -411,20 +435,16 @@ cunn_SoftMaxXEntropyForward(
}
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
apply(scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
gradInput += blockIdx.x * classes;
logits += blockIdx.x * classes;
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
......@@ -433,6 +453,7 @@ cunn_SoftMaxXEntropyBackward(
int offset = threadIdx.x;
int last = classes % (ILP * blockDim.x);
for (; offset < classes - last; offset += blockDim.x * ILP) {
accscalar_t tmpLogits[ILP];
......@@ -444,22 +465,112 @@ cunn_SoftMaxXEntropyBackward(
#pragma unroll
for (int j = 0; j < ILP; ++j)
gradInput[offset + j * blockDim.x] = tmpGradOutput * (
std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
(offset + j * blockDim.x == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
std::exp(tmpLogits[j] - coeff) - static_cast<accscalar_t>(
(offset + j * blockDim.x == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>((offset == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ __forceinline__ void
aligned_apply(int shift,
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
accscalar_t smooth_positives = 1.0 - smoothing;
accscalar_t smooth_negatives = smoothing / classes;
accscalar_t tmpGradOutput = gradOutput[blockIdx.x];
int64_t label = labels[blockIdx.x];
accscalar_t coeff = max_log_sum_exp[blockIdx.x];
int offset = threadIdx.x;
// shift and do 1
if(shift > 0){
logits -= shift;
gradInput -= shift;
classes += shift;
if(threadIdx.x >= shift){
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
classes -= blockDim.x;
gradInput += blockDim.x;
logits += blockDim.x;
shift -= blockDim.x;
}
int last = classes % (ILP * blockDim.x);
typedef typename std::aligned_storage<ILP*sizeof(scalar_t), ILP*alignof(scalar_t)>::type LoadT;
// input
scalar_t v[ILP];
LoadT* value = reinterpret_cast<LoadT*>(&v);
// output
scalar_t r[ILP];
LoadT* result = reinterpret_cast<LoadT*>(&r);
for (; offset * ILP < (classes - last); offset += blockDim.x) {
*value = reinterpret_cast<LoadT*>(logits)[offset];
#pragma unroll
for (int j = 0; j < ILP; ++j) {
r[j] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(v[j]) - coeff) -
static_cast<accscalar_t>(((ILP * offset + j - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
reinterpret_cast<LoadT*>(gradInput)[offset] = *result;
}
offset = classes - last + threadIdx.x;
for (; offset < classes; offset += blockDim.x)
gradInput[offset] = tmpGradOutput * (std::exp(
static_cast<accscalar_t>(logits[offset]) - coeff) -
static_cast<accscalar_t>(((offset - shift) == label) ? 1 : 0) *
smooth_positives - smooth_negatives);
}
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue>
__global__ void
cunn_SoftMaxXEntropyBackward(
scalar_t *gradInput,
scalar_t *logits,
outscalar_t *max_log_sum_exp,
outscalar_t *gradOutput,
int64_t *labels,
const float smoothing,
int classes)
{
gradInput += blockIdx.x * classes;
logits += blockIdx.x * classes;
// Do vectorized load/store when input/output have same alignment
const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t);
const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t);
if (shift == shift_){
aligned_apply<ILP, scalar_t, accscalar_t, outscalar_t>(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
else {
apply<ILP, scalar_t, accscalar_t, outscalar_t>(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes);
}
}
template<template<typename, typename, typename> class Epilogue>
std::vector<Tensor> host_softmax_xentropy(
......@@ -495,13 +606,13 @@ std::vector<Tensor> host_softmax_xentropy(
// XXX: it assumes that inner_size == 1
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
const int ILP = 2;
dim3 grid(outer_size);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
using namespace at;
DISPATCH_FLOAT_AND_HALF(input.scalar_type(), 0, "host_softmax_xentropy",
using accscalar_t = at::acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
if (!half_to_float) {
cunn_SoftMaxXEntropyForward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
<<<grid, block, 2 * block.x * sizeof(accscalar_t), stream>>>(
......@@ -564,12 +675,12 @@ Tensor host_softmax_xentropy_backward(
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported");
const int ILP = 2;
dim3 grid(outer_size);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
DISPATCH_FLOAT_AND_HALF(gI.scalar_type(), 0, "host_softmax_xentropy_backward",
using accscalar_t = acc_type<scalar_t_0, true>;
const int ILP = sizeof(float4)/sizeof(scalar_t_0);
dim3 block = SoftMax_getBlockSize(ILP, dim_size);
if (!half_to_float) {
cunn_SoftMaxXEntropyBackward<ILP, scalar_t_0, accscalar_t, scalar_t_0, Epilogue>
<<<grid, block, block.x * sizeof(accscalar_t), stream>>>(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment