Commit c2b62b7f authored by JR_ZZU's avatar JR_ZZU 🌴
Browse files

delete origin files

parent 2a4864d5
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h> // for getcudnnhandle
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <cudnn_frontend.h>
#include <iostream>
#ifdef DEBUG
#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )
#else
#define DEBUG_MSG(str) do { } while ( false )
#endif
#ifdef DEBUG_CUDNN
#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )
#else
#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )
#endif
#define checkCudnnErr(...) \
do { \
int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
if (err) { \
return; \
} \
} while (0)
int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
if (code) {
printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
return 1;
}
return 0;
}
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);
#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort)
{
if (code != cudaSuccess)
{
const char * errorMessage = cudaGetErrorString(code);
fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage);
if (abort){
cudaDeviceReset();
exit(code);
}
}
}
void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
// For INT8x4 and INT8x32 we still compute standard strides here to input
// into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
if (filterFormat == CUDNN_TENSOR_NCHW) {
strideA[nbDims - 1] = 1;
for (int64_t d = nbDims - 2; d >= 0; d--) {
strideA[d] = strideA[d + 1] * dimA[d + 1];
}
} else {
// Here we assume that the format is CUDNN_TENSOR_NHWC
strideA[1] = 1;
strideA[nbDims - 1] = strideA[1] * dimA[1];
for (int64_t d = nbDims - 2; d >= 2; d--) {
strideA[d] = strideA[d + 1] * dimA[d + 1];
}
strideA[0] = strideA[2] * dimA[2];
}
}
int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
return ((filterDim - 1) * dilation) + 1;
}
int getFwdConvPaddedImageDim(int tensorDim, int pad) {
return tensorDim + (2 * pad);
}
int getFwdConvOutputDim(
int tensorDim,
int pad,
int filterDim,
int stride,
int dilation)
{
int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;
return (p);
}
enum {
X_TENSOR,
Y_TENSOR,
W_TENSOR,
Z_TENSOR,
B_TENSOR,
AFTERADD_TENSOR,
AFTERBIAS_TENSOR,
AFTERCONV_TENSOR,
OPTIONAL,
AFTEROPT_TENSOR,
};
using common_conv_descriptors =
std::tuple<cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::Tensor, cudnn_frontend::ConvDesc>;
common_conv_descriptors
create_common_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
cudnnConvolutionMode_t mode) {
const int convDim = 2;
int64_t strideA_padded[4];
int64_t outstrideA_padded[4];
int64_t filterstrideA_padded[4];
generateStrides(w_dim_padded, filterstrideA_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, strideA_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, outstrideA_padded, 4, CUDNN_TENSOR_NHWC);
return common_conv_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, strideA_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, outstrideA_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, filterstrideA_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(mode)
.setNDims(convDim)
.setStrides(convDim, convstrideA)
.setPrePadding(convDim, padA)
.setPostPadding(convDim, padA)
.setDilation(convDim, dilationA)
.build());
}
using common_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
common_convbias_descriptors
create_conv_bias_add_act_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = y_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return common_convbias_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('z')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('A') // after add
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('B') // after bias
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('C') // after conv
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
// tensor descriptors used for dgrad
enum {
X_OR_DX_TENSOR,
DY_TENSOR,
W_OR_DW_TENSOR,
SCALE_TENSOR,
RELU_TENSOR,
AFTER_DCONV_TENSOR,
AFTER_DRELU_TENSOR,
};
using dconv_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_descriptors
create_dconv_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return dconv_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
// create a cache for plan
std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
// TODO: better name
std::string getConvFusionString(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
cudnnDataType_t dataType,
std::string fusion_string) {
for(int i=0;i<4;i++) {
fusion_string += 'X';
fusion_string += std::to_string(x_dim_padded[i]);
}
for(int i=0;i<4;i++) {
fusion_string += 'W';
fusion_string += std::to_string(w_dim_padded[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'P';
fusion_string += std::to_string(padA[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'S';
fusion_string += std::to_string(convstrideA[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'D';
fusion_string += std::to_string(dilationA[i]);
}
fusion_string += 'T';
fusion_string += std::to_string(dataType);
return fusion_string;
}
cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,
std::stringstream& log_buf,
cudnn_frontend::OperationGraph& opGraph,
std::string cache_string,
bool use_heuristic = true){
auto it = plan_cache.find(cache_string);
if (it != plan_cache.end()) {
DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
return it->second;
} else {
if (use_heuristic){
// TODO: confirm which mode to use
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
// try 3 times for now as WAR for no heuristic training
int max_tries = 3, count = 0;
auto& engine_configs = heuristics.getEngineConfig(max_tries);
while(true) {
try {
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle_)
.setEngineConfig(engine_configs[count], opGraph.getTag())
.build()));
break;
} catch (cudnn_frontend::cudnnException e) {
if (++count == max_tries) throw e;
}
}
}else{
DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
// How many engines support this operation graph ?
auto total_engines = opGraph.getEngineCount();
DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
// We have to randomly pick one engine from [0, total_engines)
// Selecting "0" by default
auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
DEBUG_CUDNN_MSG(log_buf, engine.describe());
auto& knobs = engine.getSupportedKnobs();
for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {
DEBUG_CUDNN_MSG(log_buf, it->describe());
}
if (knobs.begin() != knobs.end()) {
DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
}
// Createmplacee the requisite engine config
auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));
}
return plan_cache.find(cache_string)->second;
}
}
void
run_conv_scale_bias_add_activation(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
// Define the add operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create a optional add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &scale_op, &bias_op, devPtrI ? &add_op : &act_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(devPtrI ? ops.size() : 4, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(devPtrI ? 6 : 5, data_ptrs)
.setUids(devPtrI ? 6 : 5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_conv_scale_bias(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
// Define the add operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors)) // TODO: change enum to aftermul
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &scale_op, &add_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_drelu_dscale(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_descriptors tensors = create_dconv_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &scale_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR};
int64_t uids[] = {'x', 'y', 'w', 's', 'r'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
cudnnBackendDescriptorType_t mode) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_descriptors tensors = create_dconv_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
// mode should be one of following
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);
if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
conv_op_builder.setdxDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta);
}
else {
conv_op_builder.setxDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setdwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta);
}
auto conv_op = conv_op_builder.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW};
int64_t uids[] = {'x', 'y', 'w'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_add(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrR) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_descriptors tensors = create_dconv_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the add backward operation
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<RELU_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &add_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrR};
int64_t uids[] = {'x', 'y', 'w', 'r'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(4, data_ptrs)
.setUids(4, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
// inputs contains x,w,z,b,(i)
std::vector<at::Tensor> bottleneck_forward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t dimA[] = {0, 0, 0, 0};
int64_t filterdimA1[] = {0, 0, 0, 0};
int64_t filterdimA2[] = {0, 0, 0, 0};
int64_t filterdimA3[] = {0, 0, 0, 0};
int64_t filterdimA4[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[] {0,1,2,3};
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
}
for (int dim=0;dim<4;dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim=0;dim<4;dim++) {
filterdimA4[dim] = inputs[10].size(axis[dim]);
}
}
// output dim in n,c,h,w used by backend
int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below
int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below
int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below
// use these fixed value for test run
int64_t padA[] = {0, 0};
int64_t padA1[] = {1, 1};
int64_t dilationA[] = {1, 1};
int64_t convstrideA[] = {1, 1};
int64_t convstride1X1[] = {stride_1X1, stride_1X1};
// compute output from pad/stride/dilation
outdimA1[0] = dimA[0];
outdimA1[1] = filterdimA1[0];
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
for (int dim = 0; dim < 2; dim++) {
outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
}
outdimA3[0] = outdimA2[0];
outdimA3[1] = filterdimA3[0];
for (int dim = 0; dim < 2; dim++) {
outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
// Create output tensor in the correct shape in pytorch's view
int64_t outdim1[] = {0, 0, 0, 0};
int64_t outdim2[] = {0, 0, 0, 0};
int64_t outdim3[] = {0, 0, 0, 0};
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
axis[2] = 3;
axis[3] = 1;
}
for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* b = inputs[7].data_ptr<at::Half>();
auto out1 = at::empty(outdim1, inputs[0].type(), output_format);
at::Half* y1 = out1.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(dimA,
padA,
convstride1X1,
dilationA,
filterdimA1,
outdimA1,
CUDNN_DATA_HALF,
x,
w,
y1,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
w = inputs[2].data_ptr<at::Half>();
z = inputs[5].data_ptr<at::Half>();
b = inputs[8].data_ptr<at::Half>();
auto out2 = at::empty(outdim2, inputs[0].type(), output_format);
at::Half* y2 = out2.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(outdimA1,
padA1,
convstrideA,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
// create output of conv3
auto out3 = at::empty(outdim3, inputs[0].type(), output_format);
at::Half* y3 = out3.data_ptr<at::Half>();
// create output of conv4 that may exist
auto identity = at::empty_like(out3);
at::Half* yi = identity.data_ptr<at::Half>();
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){
w = inputs[10].data_ptr<at::Half>();
z = inputs[11].data_ptr<at::Half>();
b = inputs[12].data_ptr<at::Half>();
run_conv_scale_bias(dimA,
padA,
convstride1X1,
dilationA,
filterdimA4,
outdimA3,
CUDNN_DATA_HALF,
x,
w,
yi,
z,
b);
DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
}
else {
yi = x;
}
w = inputs[3].data_ptr<at::Half>();
z = inputs[6].data_ptr<at::Half>();
b = inputs[9].data_ptr<at::Half>();
run_conv_scale_bias_add_activation(outdimA2,
padA,
convstrideA,
dilationA,
filterdimA3,
outdimA3,
CUDNN_DATA_HALF,
y2,
w,
y3,
z,
b,
yi);
DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
outputs.push_back(out1);
outputs.push_back(out2);
outputs.push_back(out3);
return outputs;
}
std::vector<at::Tensor> bottleneck_backward(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t dimA[] = {0, 0, 0, 0};
int64_t filterdimA1[] = {0, 0, 0, 0};
int64_t filterdimA2[] = {0, 0, 0, 0};
int64_t filterdimA3[] = {0, 0, 0, 0};
int64_t filterdimA4[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[] {0,1,2,3};
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
}
for (int dim=0;dim<4;dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim=0;dim<4;dim++) {
filterdimA4[dim] = inputs[14].size(axis[dim]);
}
}
// output dim in n,c,h,w used by backend
int64_t outdimA1[] = {0, 0, 0, 0}; // Computed Below
int64_t outdimA2[] = {0, 0, 0, 0}; // Computed Below
int64_t outdimA3[] = {0, 0, 0, 0}; // Computed Below
// use these fixed value for test run
int64_t padA[] = {0, 0};
int64_t padA1[] = {1, 1};
int64_t dilationA[] = {1, 1};
int64_t convstrideA[] = {1, 1};
int64_t convstride1X1[] = {stride_1X1, stride_1X1};
// compute output from pad/stride/dilation
outdimA1[0] = dimA[0];
outdimA1[1] = filterdimA1[0];
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
for (int dim = 0; dim < 2; dim++) {
outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
}
outdimA3[0] = outdimA2[0];
outdimA3[1] = filterdimA3[0];
for (int dim = 0; dim < 2; dim++) {
outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
// Create output tensor in the correct shape in pytorch's view
int64_t outdim1[] = {0, 0, 0, 0};
int64_t outdim2[] = {0, 0, 0, 0};
int64_t outdim3[] = {0, 0, 0, 0};
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
axis[2] = 3;
axis[3] = 1;
}
for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
}
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
// wgrad
auto wgrad3 = at::empty_like(inputs[3]);
at::Half* dw3 = wgrad3.data_ptr<at::Half>();
run_dconv(outdimA2,
padA,
convstrideA,
dilationA,
filterdimA3,
outdimA3,
CUDNN_DATA_HALF,
conv_in,
dw3,
dy3,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
auto grad_out2 = at::empty(outdim2, inputs[0].type(), output_format);
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
at::Half* w = inputs[3].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* relu2 = inputs[13].data_ptr<at::Half>();
run_dconv_drelu_dscale(outdimA2,
padA,
convstrideA,
dilationA,
filterdimA3,
outdimA3,
CUDNN_DATA_HALF,
dy2,
w,
dy3,
z,
relu2);
DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());
// dconv2+drelu1+dscale1
conv_in = inputs[12].data_ptr<at::Half>();
// wgrad
auto wgrad2 = at::empty_like(inputs[2]);
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
run_dconv(outdimA1,
padA1,
convstrideA,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
auto grad_out1 = at::empty(outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
w = inputs[2].data_ptr<at::Half>();
z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
// fused dgrad
run_dconv_drelu_dscale(outdimA1,
padA1,
convstrideA,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1);
/*
// backward strided conv cannot be fused
// if stride == 1 but channel changes, we can fuse here
if (stride_1X1 != 1){
// dgrad
run_dconv(outdimA1,
padA1,
convstride1X1,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// mul fused mask
grad_out1.mul_(inputs[15]);
}
else {
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
// fused dgrad
run_dconv_drelu_dscale(outdimA1,
padA1,
convstride1X1,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1);
}
*/
DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());
// create grads of conv4 that may exist
auto grad_x_conv4 = at::empty_like(inputs[0]);
at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
at::Tensor wgrad4;
// x used for dconv1 and dconv4 wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]){
w = inputs[14].data_ptr<at::Half>();
at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
if (requires_grad) {
run_dconv(dimA,
padA,
convstride1X1,
dilationA,
filterdimA4,
outdimA3,
CUDNN_DATA_HALF,
dx_conv4,
w,
dy_conv4,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
// DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
}
// wgrad
wgrad4 = at::empty_like(inputs[14]);
at::Half* dw4 = wgrad4.data_ptr<at::Half>();
run_dconv(dimA,
padA,
convstride1X1,
dilationA,
filterdimA4,
outdimA3,
CUDNN_DATA_HALF,
x,
dw4,
dy_conv4,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
}
else {
// if there is no downsample, dx_conv4 is fork of drelu3
dx_conv4 = inputs[11].data_ptr<at::Half>();
}
// dconv1+add
// wgrad
auto wgrad1 = at::empty_like(inputs[1]);
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(dimA,
padA,
convstride1X1,
dilationA,
filterdimA1,
outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// dgrad
w = inputs[1].data_ptr<at::Half>();
auto grad_x = at::empty_like(inputs[0]);
at::Half* dx = grad_x.data_ptr<at::Half>();
// backward strided conv cannot be fused
// if stride == 1 but channel changes, we can fuse here
if (requires_grad){
if (stride_1X1 != 1){
run_dconv(dimA,
padA,
convstride1X1,
dilationA,
filterdimA1,
outdimA1,
CUDNN_DATA_HALF,
dx,
w,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// add 2 together
grad_x.add_(grad_x_conv4);
}
else {
run_dconv_add(dimA,
padA,
convstride1X1,
dilationA,
filterdimA1,
outdimA1,
CUDNN_DATA_HALF,
dx,
w,
dy1,
dx_conv4);
}
}
DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
outputs.push_back(grad_x);
outputs.push_back(wgrad1);
outputs.push_back(wgrad2);
outputs.push_back(wgrad3);
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
outputs.push_back(wgrad4);
}
return outputs;
}
namespace {
enum {
X_TENSOR,
Y_TENSOR,
W_TENSOR,
Z_TENSOR,
B_TENSOR,
AFTERADD_TENSOR,
AFTERBIAS_TENSOR,
AFTERCONV_TENSOR,
OPTIONAL,
AFTEROPT_TENSOR,
AFTERACT_TENSOR,
GEN_INDEX_TENSOR,
MASK_TOP_TENSOR,
MASK_BOTTOM_TENSOR,
MASK_TENSOR,
THRESHOLD_TOP_TENSOR,
THRESHOLD_BOTTOM_TENSOR,
};
using masked_convbias_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
masked_convbias_descriptors
create_conv_bias_add_act_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = y_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return masked_convbias_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('z')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('A') // after add
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setVirtual()
.setId('B') // after bias
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('C') // after conv
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('E') // after act for masked
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
// tensor descriptors used for dgrad
enum {
X_OR_DX_TENSOR,
DY_TENSOR,
W_OR_DW_TENSOR,
SCALE_TENSOR,
RELU_TENSOR,
AFTER_DCONV_TENSOR,
AFTER_DRELU_TENSOR,
DGRAD_INPUT_TENSOR,
DGRAD_OPTIONAL_TENSOR,
DGRAD_GEN_INDEX_TENSOR,
DGRAD_MASK_TOP_TENSOR,
DGRAD_MASK_BOTTOM_TENSOR,
DGRAD_MASK_TENSOR,
DGRAD_THRESHOLD_TOP_TENSOR,
DGRAD_THRESHOLD_BOTTOM_TENSOR,
};
using dconv_add_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_add_descriptors
create_dconv_add_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
return dconv_add_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build());
}
using dconv_mask_descriptors = std::tuple<cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor,
cudnn_frontend::Tensor>;
dconv_mask_descriptors
create_dconv_mask_descriptors(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType) {
const int convDim = 2;
int64_t b_dim_padded[4];
b_dim_padded[0] = 1;
b_dim_padded[1] = x_dim_padded[1];
b_dim_padded[2] = 1;
b_dim_padded[3] = 1;
int64_t x_stride_padded[4];
int64_t y_stride_padded[4];
int64_t w_stride_padded[4];
int64_t b_stride_padded[4];
int64_t threshold_stride[4];
generateStrides(w_dim_padded, w_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(x_dim_padded, x_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(y_dim_padded, y_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(b_dim_padded, b_stride_padded, 4, CUDNN_TENSOR_NHWC);
generateStrides(threshold_dim, threshold_stride, 4, CUDNN_TENSOR_NHWC);
return dconv_mask_descriptors(cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, w_dim_padded)
.setStrides(4, w_stride_padded)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, b_dim_padded)
.setStrides(4, b_stride_padded)
.setId('s')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('A') // after dconv
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, x_dim_padded)
.setStrides(4, x_stride_padded)
.setVirtual()
.setId('B') // after drelu
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('i')
.setAlignment(16)
.setDataType(dataType)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('D') // after optional add
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_FLOAT)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('I') // output of the gen index operation
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('m') // top half of the mask created after the less than
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('n') // bottom half of the mask
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, y_dim_padded)
.setStrides(4, y_stride_padded)
.setId('M') // OR of the top and bottom masks
.setAlignment(16)
.setVirtual()
.setDataType(CUDNN_DATA_BOOLEAN)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('t') // threshold for creating the top mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build(),
cudnn_frontend::TensorBuilder()
.setDim(4, threshold_dim)
.setStrides(4, threshold_stride)
.setId('u') // threshold for creating the bottom mask
.setAlignment(16)
.setDataType(CUDNN_DATA_INT32)
.build());
}
void
run_conv_add_scale_bias_activation(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
common_convbias_descriptors tensors = create_conv_bias_add_act_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTEROPT_TENSOR>(tensors).describe());
// Define the add operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// create an add node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(add_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 5> ops = {&conv_op, &add_op, &scale_op, &bias_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_conv_scale_bias_add_activation_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrB,
at::Half* devPtrI,
int* devPtrT,
int* devPtrU,
int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
masked_convbias_descriptors tensors = create_conv_bias_add_act_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Y_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<Z_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<B_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERADD_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERBIAS_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<OPTIONAL>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTERACT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the add operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(std::get<X_TENSOR>(tensors))
.setwDesc(std::get<W_TENSOR>(tensors))
.setyDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Add Node with scaling parameters.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(std::get<Z_TENSOR>(tensors))
.setyDesc(std::get<AFTERADD_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(scale_op.getOutputTensor())
.setbDesc(std::get<B_TENSOR>(tensors))
.setyDesc(std::get<AFTERBIAS_TENSOR>(tensors))
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create a optional add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(std::get<OPTIONAL>(tensors))
.setyDesc(std::get<AFTEROPT_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(devPtrI ? add_op.getOutputTensor() : bias_op.getOutputTensor())
.setyDesc(std::get<AFTERACT_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERACT_TENSOR>(tensors))
.setyDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTERCONV_TENSOR>(tensors))
.setbDesc(std::get<AFTERACT_TENSOR>(tensors))
.settDesc(std::get<MASK_TENSOR>(tensors))
.setyDesc(std::get<Y_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
if (devPtrI) {
std::array<cudnn_frontend::Operation const*, 10> ops = {&conv_op, &scale_op, &bias_op, &add_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrI, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 'i', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(8, data_ptrs)
.setUids(8, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} else {
std::array<cudnn_frontend::Operation const*, 9> ops = {&conv_op, &scale_op, &bias_op, &act_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrB, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 'z', 'b', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
}
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_add_drelu_dscale(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
at::Half* devPtrI) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_add_descriptors tensors = create_dconv_add_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_INPUT_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// optional add
auto addDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, addDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create add Node.
auto add_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_INPUT_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(addDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, add_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &add_op, &act_op, &scale_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrI};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 'i'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(6, data_ptrs)
.setUids(6, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_drelu_dscale_mask(int64_t* x_dim_padded,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
int64_t* w_dim_padded,
int64_t* y_dim_padded,
int64_t* threshold_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
at::Half* devPtrZ,
at::Half* devPtrR,
int* devPtrT,
int* devPtrU,
int axis) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
// Creates the necessary tensor descriptors
dconv_mask_descriptors tensors = create_dconv_mask_descriptors(
x_dim_padded, pad, convstride, dilation, w_dim_padded, y_dim_padded, threshold_dim, dataType);
DEBUG_CUDNN_MSG(log_buf, std::get<X_OR_DX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DY_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<W_OR_DW_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<SCALE_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<RELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DCONV_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<AFTER_DRELU_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_OPTIONAL_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_GEN_INDEX_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_MASK_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors).describe());
DEBUG_CUDNN_MSG(log_buf, std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors).describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the scale backward operation
auto scaleDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, scaleDesc.describe());
// Define the genIndex descriptor
auto genIndexDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_GEN_INDEX)
.setMathPrecision(CUDNN_DATA_FLOAT)
.setAxis(axis)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndexDesc.describe());
// Define the lessThan descriptor
auto lessThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_LT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThanDesc.describe());
// Define the greaterThan descriptor
auto greaterThanDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_CMP_GT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThanDesc.describe());
// Define the logical_or descriptor
auto logicalOrDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_LOGICAL_OR)
.setMathPrecision(CUDNN_DATA_BOOLEAN)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOrDesc.describe());
// Define the binary_selection descriptor
auto selectionDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_BINARY_SELECT)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, selectionDesc.describe());
float alpha = 1.0f;
float beta = 0.0f;
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setwDesc(std::get<W_OR_DW_TENSOR>(tensors))
.setdyDesc(std::get<DY_TENSOR>(tensors))
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// TODO: do we need getOutputTensor(), and what it returns in backward case?
// Create an relu backward Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setxDesc(std::get<RELU_TENSOR>(tensors))
.setdxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create a Scale Node.
auto scale_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DRELU_TENSOR>(tensors))
.setbDesc(std::get<SCALE_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setpwDesc(scaleDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, scale_op.describe());
// Create a Gen_Index Node.
auto genIndex_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setpwDesc(genIndexDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, genIndex_op.describe());
// Create a LessThan Node.
auto lessThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_TOP_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setpwDesc(lessThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, lessThan_op.describe());
// Create a GreaterThan Node.
auto greaterThan_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_GEN_INDEX_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_THRESHOLD_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setpwDesc(greaterThanDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, greaterThan_op.describe());
// Create a LogicalOr Node.
auto logicalOr_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<DGRAD_MASK_TOP_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_MASK_BOTTOM_TENSOR>(tensors))
.setyDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setpwDesc(logicalOrDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, logicalOr_op.describe());
// Create a Binary_Selection Node.
auto selection_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(std::get<AFTER_DCONV_TENSOR>(tensors))
.setbDesc(std::get<DGRAD_OPTIONAL_TENSOR>(tensors))
.settDesc(std::get<DGRAD_MASK_TENSOR>(tensors))
.setyDesc(std::get<X_OR_DX_TENSOR>(tensors))
.setpwDesc(selectionDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, selection_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 8> ops = {&conv_op, &act_op, &scale_op, &genIndex_op, &lessThan_op, &greaterThan_op, &logicalOr_op, &selection_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim_padded, pad, convstride, dilation, w_dim_padded, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY, devPtrW, devPtrZ, devPtrR, devPtrT, devPtrU};
int64_t uids[] = {'x', 'y', 'w', 's', 'r', 't', 'u'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(7, data_ptrs)
.setUids(7, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
struct bottleneck_forward_status {
int64_t dimA[4];
int64_t filterdimA1[4];
int64_t filterdimA2[4];
int64_t filterdimA2hh[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t threshdim[4];
int axis[4];
int64_t outdimA0[4];
int64_t outdimA1[4];
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4];
int64_t outdimA3[4];
int64_t outdimA4[4];
int64_t padA[2];
int64_t padA1[2];
int64_t padA2[2]; // halo padding
int64_t dilationA[2];
int64_t convstrideA[2];
int64_t convstride1X1[2];
int64_t outdim0[4]; // halo input shape
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim4[4]; // halo output shape
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
} else {
axis[0] = 0;
axis[1] = 1;
axis[2] = 2;
axis[3] = 3;
}
for (int dim=0;dim<4;dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim=0;dim<4;dim++) {
filterdimA4[dim] = inputs[10].size(axis[dim]);
}
}
for (int dim=0;dim<4;dim++) {
if (dim == 2) {
filterdimA2hh[dim] = 1;
} else {
filterdimA2hh[dim] = filterdimA2[dim];
}
}
// output dim in n,c,h,w used by backend
outdimA0[0] = outdimA0[1] = outdimA0[2] = outdimA0[3] = 0;
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA4[0] = outdimA4[1] = outdimA4[2] = outdimA4[3] = 0;
// use these fixed value for test run
padA[0] = 0; padA[1] = 0;
padA1[0] = 1; padA1[1] = 1;
padA2[0] = 0; padA2[1] = 1;
dilationA[0] = 1; dilationA[1] = 1;
convstrideA[0] = 1; convstrideA[1] = 1;
convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;
// compute output from pad/stride/dilation
outdimA1[0] = dimA[0];
outdimA1[1] = filterdimA1[0];
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
for (int dim = 0; dim < 2; dim++) {
outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA0[dim] = 3;
outdimA4[dim] = 1;
} else {
outdimA0[dim] = outdimA1[dim];
outdimA4[dim] = outdimA2[dim];
}
}
outdimA3[0] = outdimA2[0];
outdimA3[1] = filterdimA3[0];
for (int dim = 0; dim < 2; dim++) {
outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
axis[2] = 3;
axis[3] = 1;
}
for (int dim=0;dim<4;dim++) {
outdim0[dim] = outdimA0[axis[dim]];
outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim4[dim] = outdimA4[axis[dim]];
}
}
};
bottleneck_forward_status forward_state;
} // end of anonymous namespace
std::vector<at::Tensor> bottleneck_forward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// NB! Bottleneck_forward and bottleneck_backward are NOT thread safe method.
// NB! We use a global object to store state.
forward_state.init(explicit_nhwc, stride_1X1, inputs);
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
//printf("outdim1 = (%d,%d,%d,%d)\n",forward_state.outdim1[0],forward_state.outdim1[1],forward_state.outdim1[2],forward_state.outdim1[3]);
auto out1 = at::empty(forward_state.outdim1, inputs[0].type(), output_format);
auto out2 = at::empty(forward_state.outdim2, inputs[0].type(), output_format);
auto out3 = at::empty(forward_state.outdim3, inputs[0].type(), output_format);
outputs.push_back(out1);
outputs.push_back(out2);
outputs.push_back(out3);
return outputs;
}
// inputs contains x,w,z,b,(i)
void bottleneck_forward_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* b = inputs[7].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(forward_state.dimA,
forward_state.padA,
forward_state.convstride1X1,
forward_state.dilationA,
forward_state.filterdimA1,
forward_state.outdimA1,
CUDNN_DATA_HALF,
x,
w,
y1,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu1 : " << out1.to(at::kFloat).sum().item<float>());
}
// computes halo (top or bottom) from fat halo input.
// fat halo input is 3 pixels wide in H.
at::Tensor bottleneck_forward_out2_halo(bool explicit_nhwc, at::Tensor fat_halo_y1, std::vector<at::Tensor> inputs) {
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
at::Half* y1 = fat_halo_y1.data_ptr<at::Half>();
auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
at::Half* y2 = halo_y2.data_ptr<at::Half>();
run_conv_scale_bias_add_activation(forward_state.outdimA0,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA4,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
return halo_y2;
}
// compute halo correction term (top or bottom) from slim halo input (N,C,1,W).
// slim halo input is 1 pixel wide in H.
at::Tensor bottleneck_forward_out2_halo_corr(bool explicit_nhwc, at::Tensor slim_halo_y1, std::vector<at::Tensor> inputs, at::Tensor w1by3, at::Tensor out2_part_halo) {
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// run
at::Half* w = w1by3.data_ptr<at::Half>(); // C,C,1,3
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
at::Half* y1 = slim_halo_y1.data_ptr<at::Half>();
at::Half* prev_out2 = out2_part_halo.data_ptr<at::Half>();
auto halo_y2 = at::empty(forward_state.outdim4, inputs[0].type(), output_format);
at::Half* y2 = halo_y2.data_ptr<at::Half>();
run_conv_add_scale_bias_activation(forward_state.outdimA4,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2hh,
forward_state.outdimA4,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
prev_out2);
return halo_y2;
}
void bottleneck_forward_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation(forward_state.outdimA1,
forward_state.padA1,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation_mask(forward_state.outdimA1,
forward_state.padA1,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
forward_state.threshdim,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2); // axis == 1 -> Does this assume explicit NHWC?
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_out2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor out1_pad) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
auto out1 = outputs[0];
at::Half* y1 = out1_pad.data_ptr<at::Half>();
// run
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* b = inputs[8].data_ptr<at::Half>();
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
//printf("forward_state.outdimA1 = {%d,%d,%d,%d}\n",forward_state.outdimA1[0],forward_state.outdimA1[1],forward_state.outdimA1[2],forward_state.outdimA1[3]);
//printf("forward_state.padA1 = {%d,%d}\n",forward_state.padA1[0],forward_state.padA1[1]);
//printf("forward_state.convstrideA = {%d,%d}\n",forward_state.convstrideA[0],forward_state.convstrideA[1]);
//printf("forward_state.dilationA = {%d,%d}\n",forward_state.dilationA[0],forward_state.dilationA[1]);
//printf("forward_state.filterdimA2 = {%d,%d,%d,%d}\n",forward_state.filterdimA2[0],forward_state.filterdimA2[1],forward_state.filterdimA2[2],forward_state.filterdimA2[3]);
//printf("forward_state.outdimA2 = {%d,%d,%d,%d}\n",forward_state.outdimA2[0],forward_state.outdimA2[1],forward_state.outdimA2[2],forward_state.outdimA2[3]);
run_conv_scale_bias_add_activation(forward_state.outdimA1b,
forward_state.padA2,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA2,
forward_state.outdimA2,
CUDNN_DATA_HALF,
y1,
w,
y2,
z,
b,
nullptr);
DEBUG_MSG("[DEBUG] new relu2 : " << out2.to(at::kFloat).sum().item<float>());
}
void bottleneck_forward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
std::cout << std::fixed;
// from _out1 method
at::Half* x = inputs[0].data_ptr<at::Half>();
// create output of conv3
auto out3 = outputs[2];
at::Half* y3 = out3.data_ptr<at::Half>();
// create output of conv4 that may exist
auto identity = at::empty_like(out3);
at::Half* yi = identity.data_ptr<at::Half>();
at::Half *w, *z, *b;
if (stride_1X1 != 1 || forward_state.filterdimA3[0] != forward_state.dimA[1]){
w = inputs[10].data_ptr<at::Half>();
z = inputs[11].data_ptr<at::Half>();
b = inputs[12].data_ptr<at::Half>();
run_conv_scale_bias(forward_state.dimA,
forward_state.padA,
forward_state.convstride1X1,
forward_state.dilationA,
forward_state.filterdimA4,
forward_state.outdimA3,
CUDNN_DATA_HALF,
x,
w,
yi,
z,
b);
DEBUG_MSG("[DEBUG] new downsample : " << identity.to(at::kFloat).sum().item<float>());
}
else {
yi = x;
}
auto out2 = outputs[1];
at::Half* y2 = out2.data_ptr<at::Half>();
w = inputs[3].data_ptr<at::Half>();
z = inputs[6].data_ptr<at::Half>();
b = inputs[9].data_ptr<at::Half>();
run_conv_scale_bias_add_activation(forward_state.outdimA2,
forward_state.padA,
forward_state.convstrideA,
forward_state.dilationA,
forward_state.filterdimA3,
forward_state.outdimA3,
CUDNN_DATA_HALF,
y2,
w,
y3,
z,
b,
yi);
DEBUG_MSG("[DEBUG] new relu3 : " << out3.to(at::kFloat).sum().item<float>());
}
namespace {
struct bottleneck_backward_state {
int64_t dimA[4];
int64_t filterdimA1[4];
int64_t filterdimA2[4];
int64_t filterdimA3[4];
int64_t filterdimA4[4];
int64_t filterdimA2hh[4]; // Cin,Cout,1,3
int64_t threshdim[4];
int axis[4];
int64_t outdimA1[4]; // grad_out1
int64_t outdimA1b[4]; // out1_pad
int64_t outdimA2[4]; // grad_out2
int64_t outdimA3[4];
int64_t outdimA1h[4]; // output: grad_out1 halo (H=3)
int64_t outdimA2h[4]; // input : grad_out2 halo cells (H=3)
int64_t outdimA1hh[4]; // input: grad_out2 halo (H=1)
int64_t outdimA2hh[4]; // input: out1 halo (H=1)
int64_t padA[2];
int64_t padA1[2];
int64_t padA2[2];
int64_t dilationA[2];
int64_t convstrideA[2];
int64_t convstride1X1[2];
int64_t filterdim2hh[4]; // Cin,1,3,Cout
int64_t outdim1[4];
int64_t outdim1b[4];
int64_t outdim2[4];
int64_t outdim3[4];
int64_t outdim1h[4];
int64_t outdim1hh[4];
void init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
// setup dimensions
dimA[0] = dimA[1] = dimA[2] = dimA[3] = 0;
filterdimA1[0] = filterdimA1[1] = filterdimA1[2] = filterdimA1[3] = 0;
filterdimA2[0] = filterdimA2[1] = filterdimA2[2] = filterdimA2[3] = 0;
filterdimA3[0] = filterdimA3[1] = filterdimA3[2] = filterdimA3[3] = 0;
filterdimA4[0] = filterdimA4[1] = filterdimA4[2] = filterdimA4[3] = 0;
filterdimA2hh[0] = filterdimA2hh[1] = filterdimA2hh[2] = filterdimA2hh[3] = 0;
threshdim[0] = threshdim[1] = threshdim[2] = threshdim[3] = 1;
// All dim calculation after this order of n,c,h,w
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 3;
axis[2] = 1;
axis[3] = 2;
} else {
axis[0] = 0;
axis[1] = 1;
axis[2] = 2;
axis[3] = 3;
}
for (int dim=0;dim<4;dim++) {
dimA[dim] = inputs[0].size(axis[dim]);
filterdimA1[dim] = inputs[1].size(axis[dim]);
filterdimA2[dim] = inputs[2].size(axis[dim]);
filterdimA3[dim] = inputs[3].size(axis[dim]);
}
if (stride_1X1 != 1 || filterdimA3[0] != dimA[1]) {
for (int dim=0;dim<4;dim++) {
filterdimA4[dim] = inputs[14].size(axis[dim]);
}
}
for (int dim=0;dim<4;dim++) {
if (dim == 2) {
filterdimA2hh[dim] = 1;
} else {
filterdimA2hh[dim] = filterdimA2[dim];
}
}
// output dim in n,c,h,w used by backend
outdimA1[0] = outdimA1[1] = outdimA1[2] = outdimA1[3] = 0;
outdimA1b[0] = outdimA1b[1] = outdimA1b[2] = outdimA1b[3] = 0;
outdimA2[0] = outdimA2[1] = outdimA2[2] = outdimA2[3] = 0;
outdimA3[0] = outdimA3[1] = outdimA3[2] = outdimA3[3] = 0;
outdimA1h[0] = outdimA1h[1] = outdimA1h[2] = outdimA1h[3] = 0;
outdimA2h[0] = outdimA2h[1] = outdimA2h[2] = outdimA2h[3] = 0;
outdimA1hh[0] = outdimA1hh[1] = outdimA1hh[2] = outdimA1hh[3] = 0;
outdimA2hh[0] = outdimA2hh[1] = outdimA2hh[2] = outdimA2hh[3] = 0;
// use these fixed value for test run
padA[0] = 0; padA[1] = 0;
padA1[0] = 1; padA1[1] = 1;
padA2[0] = 0; padA2[1] = 1;
dilationA[0] = 1; dilationA[1] = 1;
convstrideA[0] = 1; convstrideA[1] = 1;
convstride1X1[0] = stride_1X1; convstride1X1[1] = stride_1X1;
// compute output from pad/stride/dilation
outdimA1[0] = dimA[0];
outdimA1[1] = filterdimA1[0];
for (int dim = 0; dim < 2; dim++) {
outdimA1[dim + 2] = getFwdConvOutputDim(dimA[dim + 2], padA[dim], filterdimA1[dim + 2], convstride1X1[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1b[dim] = outdimA1[dim] + 2;
} else {
outdimA1b[dim] = outdimA1[dim];
}
}
outdimA2[0] = outdimA1[0];
outdimA2[1] = filterdimA2[0];
for (int dim = 0; dim < 2; dim++) {
outdimA2[dim + 2] = getFwdConvOutputDim(outdimA1[dim + 2], padA1[dim], filterdimA2[dim + 2], convstrideA[dim], dilationA[dim]);
}
outdimA3[0] = outdimA2[0];
outdimA3[1] = filterdimA3[0];
for (int dim = 0; dim < 2; dim++) {
outdimA3[dim + 2] = getFwdConvOutputDim(outdimA2[dim + 2], padA[dim], filterdimA3[dim + 2], convstrideA[dim], dilationA[dim]);
}
for (int dim = 0; dim < 4; dim++) {
if (dim == 2) {
outdimA1h[dim] = 3;
outdimA2h[dim] = 3;
outdimA1hh[dim] = 1;
outdimA2hh[dim] = 1;
} else {
outdimA1h[dim] = outdimA1[dim];
outdimA2h[dim] = outdimA2[dim];
outdimA1hh[dim] = outdimA1[dim];
outdimA2hh[dim] = outdimA2[dim];
}
}
// Create output tensor in the correct shape in pytorch's view
outdim1[0] = outdim1[1] = outdim1[2] = outdim1[3] = 0;
outdim1b[0] = outdim1b[1] = outdim1b[2] = outdim1b[3] = 0;
outdim2[0] = outdim2[1] = outdim2[2] = outdim2[3] = 0;
outdim3[0] = outdim3[1] = outdim3[2] = outdim3[3] = 0;
outdim1h[0] = outdim1h[1] = outdim1h[2] = outdim1h[3] = 0;
outdim1hh[0] = outdim1hh[1] = outdim1hh[2] = outdim1hh[3] = 0;
filterdim2hh[0] = filterdim2hh[1] = filterdim2hh[2] = filterdim2hh[3] = 0;
if (explicit_nhwc) {
axis[0] = 0;
axis[1] = 2;
axis[2] = 3;
axis[3] = 1;
}
for (int dim=0;dim<4;dim++) {
outdim1[dim] = outdimA1[axis[dim]];
outdim1b[dim] = outdimA1b[axis[dim]];
outdim2[dim] = outdimA2[axis[dim]];
outdim3[dim] = outdimA3[axis[dim]];
outdim1h[dim] = outdimA1h[axis[dim]];
outdim1hh[dim] = outdimA1hh[axis[dim]];
filterdim2hh[dim] = filterdimA2hh[axis[dim]];
}
}
};
bottleneck_backward_state backward_state;
}
std::vector<at::Tensor> bottleneck_backward_init(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs) {
std::cout << std::fixed;
backward_state.init(explicit_nhwc, stride_1X1, inputs);
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
auto grad_x = at::empty_like(inputs[0]);
auto wgrad1 = at::empty_like(inputs[1]);
auto wgrad2 = at::empty_like(inputs[2]);
auto wgrad3 = at::empty_like(inputs[3]);
outputs.push_back(grad_x);
outputs.push_back(wgrad1);
outputs.push_back(wgrad2);
outputs.push_back(wgrad3);
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
auto wgrad4 = at::empty_like(inputs[14]);
outputs.push_back(wgrad4);
}
return outputs;
}
void bottleneck_backward_wgrad3(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
// wgrad
auto wgrad3 = outputs[3];
at::Half* dw3 = wgrad3.data_ptr<at::Half>();
run_dconv(backward_state.outdimA2,
backward_state.padA,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA3,
backward_state.outdimA3,
CUDNN_DATA_HALF,
conv_in,
dw3,
dy3,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad3 : " << wgrad3.to(at::kFloat).sum().item<float>());
}
at::Tensor bottleneck_backward_grad_out2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dconv3+drelu2+dscale2
at::Half* conv_in = inputs[13].data_ptr<at::Half>();
at::Half* dy3 = inputs[10].data_ptr<at::Half>();
DEBUG_MSG("[DEBUG] new dconv3 : " << inputs[10].to(at::kFloat).sum().item<float>());
// dgrad
auto grad_out2 = at::empty(backward_state.outdim2, inputs[0].type(), output_format);
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
at::Half* w = inputs[3].data_ptr<at::Half>();
at::Half* z = inputs[5].data_ptr<at::Half>();
at::Half* relu2 = inputs[13].data_ptr<at::Half>();
run_dconv_drelu_dscale(backward_state.outdimA2,
backward_state.padA,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA3,
backward_state.outdimA3,
CUDNN_DATA_HALF,
dy2,
w,
dy3,
z,
relu2);
// do halo exchange of dy2 here
DEBUG_MSG("[DEBUG] new dconv2 : " << grad_out2.to(at::kFloat).sum().item<float>());
return grad_out2;
}
at::Tensor bottleneck_backward_grad_out1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
//printf("backward_state.outdim1 = {%d,%d,%d,%d}\n",backward_state.outdim1[0],backward_state.outdim1[1],backward_state.outdim1[2],backward_state.outdim1[3]);
run_dconv_drelu_dscale(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1);
return grad_out1;
}
at::Tensor bottleneck_backward_grad_out1_mask(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor thresholdTop, at::Tensor thresholdBottom) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dgrad
auto grad_out1 = at::empty(backward_state.outdim1, inputs[0].type(), output_format);
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",inputs[12].size(0),inputs[12].size(1),inputs[12].size(2),inputs[12].size(3));
// fused dgrad
run_dconv_drelu_dscale_mask(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
backward_state.threshdim,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1,
thresholdTop.data_ptr<int>(),
thresholdBottom.data_ptr<int>(),
2);
return grad_out1;
}
// perform backward data 1x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,1,W,C] with padding=(0,1) to produce output of shape [N,1,W,C]
at::Tensor bottleneck_backward_grad_out1_halo_corr(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, at::Tensor w1by3, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo, at::Tensor part_grad_out1) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
// dgrad
auto grad_out1_halo = at::empty(backward_state.outdim1hh, inputs[0].type(), output_format);
at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
//at::Half* w = inputs[2].data_ptr<at::Half>(); // use w1by3 instead, which is a sliced version of inputs[2]
at::Half* w = w1by3.data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
at::Half* pdy1h = part_grad_out1.data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
//printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
run_dconv_add_drelu_dscale(backward_state.outdimA1hh,
backward_state.padA2, // 0,1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2hh, // C,1,3,C
backward_state.outdimA2hh,
CUDNN_DATA_HALF,
dy1h,
w,
dy2h,
z,
relu1h,
pdy1h);
return grad_out1_halo;
}
// perform backward data 3x3 convolution (grad_out * w_rot180) on grad_out2 input of shape [N,3,W,C] with padding=(1,1) to produce output of shape [N,3,W,C]
at::Tensor bottleneck_backward_grad_out1_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2_halo, at::Tensor relu1_halo) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2h = grad_out2_halo.data_ptr<at::Half>();
// dgrad
auto grad_out1_halo = at::empty(backward_state.outdim1h, inputs[0].type(), output_format);
at::Half* dy1h = grad_out1_halo.data_ptr<at::Half>();
at::Half* w = inputs[2].data_ptr<at::Half>();
at::Half* z = inputs[4].data_ptr<at::Half>();
at::Half* relu1h = relu1_halo.data_ptr<at::Half>();
//printf("relu.shape = [%d,%d,%d,%d]\n",relu1_halo.size(0),relu1_halo.size(1),relu1_halo.size(2),relu1_halo.size(3));
// fused dgrad
//printf("backward_state.outdimA1h = {%d,%d,%d,%d}\n",backward_state.outdimA1h[0],backward_state.outdimA1h[1],backward_state.outdimA1h[2],backward_state.outdimA1h[3]);
//printf("backward_state.outdimA2h = {%d,%d,%d,%d}\n",backward_state.outdimA2h[0],backward_state.outdimA2h[1],backward_state.outdimA2h[2],backward_state.outdimA2h[3]);
//printf("backward_state.filterdimA2 = {%d,%d,%d,%d}\n",backward_state.filterdimA2[0],backward_state.filterdimA2[1],backward_state.filterdimA2[2],backward_state.filterdimA2[3]);
run_dconv_drelu_dscale(backward_state.outdimA1h,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2h,
CUDNN_DATA_HALF,
dy1h,
w,
dy2h,
z,
relu1h);
return grad_out1_halo;
}
void bottleneck_backward_wgrad2_pad(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2) {
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = input.data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
//printf("outdimA1b = (%d,%d,%d,%d)\n",backward_state.outdimA1b[0],backward_state.outdimA1b[1],backward_state.outdimA1b[2],backward_state.outdimA1b[3]);
//printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
run_dconv(backward_state.outdimA1b, // conv_in.shape (including H halos)
backward_state.padA2, // 0, 1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2, // dw2.shape
backward_state.outdimA2, // dy2.shape
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
}
void bottleneck_backward_wgrad2(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = inputs[12].data_ptr<at::Half>();
// wgrad
auto wgrad2 = outputs[2];
at::Half* dw2 = wgrad2.data_ptr<at::Half>();
//printf("outdimA1 = (%d,%d,%d,%d)\n",backward_state.outdimA1[0],backward_state.outdimA1[1],backward_state.outdimA1[2],backward_state.outdimA1[3]);
run_dconv(backward_state.outdimA1,
backward_state.padA1,
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2,
backward_state.outdimA2,
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
DEBUG_MSG("[DEBUG] new wgrad2 : " << wgrad2.to(at::kFloat).sum().item<float>());
}
// compute halo cells for input volume of dimension [N,1,W,C] with padding=(0,1) to produce output volume of dimension [N,1,W,C]
// input and grad_out2_halo tensors are all of same shape
// output tensor is of shape [Cin,1,3,Cout] (regular filter dims are [Cin,3,3,Cout]
at::Tensor bottleneck_backward_wgrad2_halo(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor input, at::Tensor grad_out2_halo) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2_halo.data_ptr<at::Half>();
// dconv2+drelu1+dscale1
at::Half* conv_in = input.data_ptr<at::Half>();
// wgrad
auto wgrad2_halo = at::empty(backward_state.filterdim2hh, input.type(), output_format);
at::Half* dw2 = wgrad2_halo.data_ptr<at::Half>();
//printf("backward_state.outdimA1hh = {%d,%d,%d,%d}\n",backward_state.outdimA1hh[0],backward_state.outdimA1hh[1],backward_state.outdimA1hh[2],backward_state.outdimA1hh[3]);
//printf("backward_state.outdimA2hh = {%d,%d,%d,%d}\n",backward_state.outdimA2hh[0],backward_state.outdimA2hh[1],backward_state.outdimA2hh[2],backward_state.outdimA2hh[3]);
//printf("backward_state.filterdim2hh = {%d,%d,%d,%d}\n",backward_state.filterdim2hh[0],backward_state.filterdim2hh[1],backward_state.filterdim2hh[2],backward_state.filterdim2hh[3]);
//printf("backward_state.filterdimA2hh = {%d,%d,%d,%d}\n",backward_state.filterdimA2hh[0],backward_state.filterdimA2hh[1],backward_state.filterdimA2hh[2],backward_state.filterdimA2hh[3]);
//printf("backward_state.padA2 = {%d,%d}\n",backward_state.padA2[0],backward_state.padA2[1]);
run_dconv(backward_state.outdimA1hh, // N,C,1,W
backward_state.padA2, // 0, 1
backward_state.convstrideA,
backward_state.dilationA,
backward_state.filterdimA2hh, // Cin,Cout,1,3
backward_state.outdimA2hh, // N,C,1,W
CUDNN_DATA_HALF,
conv_in,
dw2,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
return wgrad2_halo;
}
void bottleneck_backward_wgrad1(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out1) {
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
// dconv1+add
// wgrad
auto wgrad1 = outputs[1];
at::Half* dw1 = wgrad1.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
x,
dw1,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
}
void bottleneck_backward_rest(bool explicit_nhwc, int stride_1X1, std::vector<at::Tensor> inputs, std::vector<at::Tensor> outputs, at::Tensor grad_out2, at::Tensor grad_out1) {
bool requires_grad = inputs[0].requires_grad();
std::cout << std::fixed;
auto output_format = explicit_nhwc ? at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
// dgrad
at::Half* dy2 = grad_out2.data_ptr<at::Half>();
at::Half* dy1 = grad_out1.data_ptr<at::Half>();
/*
// backward strided conv cannot be fused
// if stride == 1 but channel changes, we can fuse here
if (stride_1X1 != 1){
// dgrad
run_dconv(outdimA1,
padA1,
convstride1X1,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// mul fused mask
grad_out1.mul_(inputs[15]);
}
else {
at::Half* relu1 = inputs[12].data_ptr<at::Half>();
// fused dgrad
run_dconv_drelu_dscale(outdimA1,
padA1,
convstride1X1,
dilationA,
filterdimA2,
outdimA2,
CUDNN_DATA_HALF,
dy1,
w,
dy2,
z,
relu1);
}
*/
DEBUG_MSG("[DEBUG] new dconv1 : " << grad_out1.to(at::kFloat).sum().item<float>());
// create grads of conv4 that may exist
auto grad_x_conv4 = at::empty_like(inputs[0]);
at::Half* dx_conv4 = grad_x_conv4.data_ptr<at::Half>();
at::Tensor wgrad4;
// x used for dconv1 and dconv4 wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = NULL;
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]){
w = inputs[14].data_ptr<at::Half>();
at::Half* dy_conv4 = inputs[11].data_ptr<at::Half>();
if (requires_grad) {
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA4,
backward_state.outdimA3,
CUDNN_DATA_HALF,
dx_conv4,
w,
dy_conv4,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// we don't print here since we can't hook out this grad in pytorch alone to compare, due to addition with dx
// DEBUG_MSG("[DEBUG] new dx_identity : " << grad_x_conv4.to(at::kFloat).sum().item<float>());
}
// wgrad
wgrad4 = outputs[4];
at::Half* dw4 = wgrad4.data_ptr<at::Half>();
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA4,
backward_state.outdimA3,
CUDNN_DATA_HALF,
x,
dw4,
dy_conv4,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
}
else {
// if there is no downsample, dx_conv4 is fork of drelu3
dx_conv4 = inputs[11].data_ptr<at::Half>();
}
// dgrad
w = inputs[1].data_ptr<at::Half>();
auto grad_x = outputs[0];
at::Half* dx = grad_x.data_ptr<at::Half>();
// backward strided conv cannot be fused
// if stride == 1 but channel changes, we can fuse here
if (requires_grad){
if (stride_1X1 != 1){
run_dconv(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
dx,
w,
dy1,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
// add 2 together
grad_x.add_(grad_x_conv4);
}
else {
run_dconv_add(backward_state.dimA,
backward_state.padA,
backward_state.convstride1X1,
backward_state.dilationA,
backward_state.filterdimA1,
backward_state.outdimA1,
CUDNN_DATA_HALF,
dx,
w,
dy1,
dx_conv4);
}
}
DEBUG_MSG("[DEBUG] new dx : " << grad_x.to(at::kFloat).sum().item<float>());
DEBUG_MSG("[DEBUG] new wgrad1 : " << wgrad1.to(at::kFloat).sum().item<float>());
if (stride_1X1 != 1 || backward_state.filterdimA3[0] != backward_state.dimA[1]) {
DEBUG_MSG("[DEBUG] new wgrad4 : " << wgrad4.to(at::kFloat).sum().item<float>());
}
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &bottleneck_forward, "Bottleneck block forward");
m.def("backward", &bottleneck_backward, "Bottleneck block backward");
m.def("forward_init", &bottleneck_forward_init, "Bottleneck block init");
m.def("forward_out1", &bottleneck_forward_out1, "Bottleneck block forward");
m.def("forward_out2", &bottleneck_forward_out2, "Bottleneck block forward");
m.def("forward_out2_mask", &bottleneck_forward_out2_mask, "Bottleneck block forward");
m.def("forward_out2_halo", &bottleneck_forward_out2_halo, "Bottleneck block forward");
m.def("forward_out2_halo_corr", &bottleneck_forward_out2_halo_corr, "Bottleneck block forward");
m.def("forward_out2_pad", &bottleneck_forward_out2_pad, "Bottleneck block forward");
m.def("forward_rest", &bottleneck_forward_rest, "Bottleneck block forward");
m.def("backward_init", &bottleneck_backward_init, "Bottleneck block backward init");
m.def("backward_grad_out2", &bottleneck_backward_grad_out2, "Bottleneck block backward");
m.def("backward_grad_out1", &bottleneck_backward_grad_out1, "Bottleneck block backward");
m.def("backward_grad_out1_mask", &bottleneck_backward_grad_out1_mask, "Bottleneck block backward");
m.def("backward_grad_out1_halo", &bottleneck_backward_grad_out1_halo, "Bottleneck block backward");
m.def("backward_grad_out1_halo_corr", &bottleneck_backward_grad_out1_halo_corr, "Bottleneck block backward");
m.def("backward_wgrad2_pad", &bottleneck_backward_wgrad2_pad, "Bottleneck block backward");
m.def("backward_wgrad2", &bottleneck_backward_wgrad2, "Bottleneck block backward");
m.def("backward_wgrad2_halo", &bottleneck_backward_wgrad2_halo, "Bottleneck block backward");
m.def("backward_wgrad3", &bottleneck_backward_wgrad3, "Bottleneck block backward");
m.def("backward_wgrad1", &bottleneck_backward_wgrad1, "Bottleneck block backward");
m.def("backward_rest", &bottleneck_backward_rest, "Bottleneck block backward");
}
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h> // for getcudnnhandle
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>
#include <cudnn_frontend.h>
#include <iostream>
#ifdef DEBUG
#define DEBUG_MSG(str) do { std::cout << str << std::endl; } while( false )
#else
#define DEBUG_MSG(str) do { } while ( false )
#endif
#ifdef DEBUG_CUDNN
#define DEBUG_CUDNN_MSG(buf, str) do { buf << str << std::endl; } while( false )
#else
#define DEBUG_CUDNN_MSG(buf, str) do { } while ( false )
#endif
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(at::MemoryFormat::ChannelsLast), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
#define checkCudnnErr(...) \
do { \
int err = checkCudnnError(__VA_ARGS__, #__VA_ARGS__, __FILE__, __LINE__); \
if (err) { \
return; \
} \
} while (0)
int checkCudnnError(cudnnStatus_t code, const char* expr, const char* file, int line) {
if (code) {
printf("CUDNN error at %s:%d, code=%d (%s) in '%s'\n", file, line, (int)code, cudnnGetErrorString(code), expr);
return 1;
}
return 0;
}
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort = true);
#define checkCUDAError(val) { checkError((val), #val, __FILE__, __LINE__); } // in-line regular function
void checkError(cudaError_t code, char const * func, const char *file, const int line, bool abort) {
if (code != cudaSuccess)
{
const char * errorMessage = cudaGetErrorString(code);
fprintf(stderr, "CUDA error returned from \"%s\" at %s:%d, Error code: %d (%s)\n", func, file, line, code, errorMessage);
if (abort){
cudaDeviceReset();
exit(code);
}
}
}
void generateStrides(const int64_t* dimA, int64_t* strideA, int nbDims, cudnnTensorFormat_t filterFormat) {
// For INT8x4 and INT8x32 we still compute standard strides here to input
// into the cuDNN functions. We will manually scale by resizeFactor in the cpu ref.
if (filterFormat == CUDNN_TENSOR_NCHW) {
strideA[nbDims - 1] = 1;
for (int64_t d = nbDims - 2; d >= 0; d--) {
strideA[d] = strideA[d + 1] * dimA[d + 1];
}
} else {
// Here we assume that the format is CUDNN_TENSOR_NHWC
strideA[1] = 1;
strideA[nbDims - 1] = strideA[1] * dimA[1];
for (int64_t d = nbDims - 2; d >= 2; d--) {
strideA[d] = strideA[d + 1] * dimA[d + 1];
}
strideA[0] = strideA[2] * dimA[2];
}
}
int getFwdConvDilatedFilterDim(int filterDim, int dilation) {
return ((filterDim - 1) * dilation) + 1;
}
int getFwdConvPaddedImageDim(int tensorDim, int pad) {
return tensorDim + (2 * pad);
}
int getFwdConvOutputDim(int tensorDim,
int pad,
int filterDim,
int stride,
int dilation) {
int p = (getFwdConvPaddedImageDim(tensorDim, pad) - getFwdConvDilatedFilterDim(filterDim, dilation)) / stride + 1;
return (p);
}
// create a cache for plan
std::unordered_map<std::string, cudnn_frontend::ExecutionPlan> plan_cache;
std::string getConvFusionString(int64_t* x_dim_padded,
int64_t* padA,
int64_t* convstrideA,
int64_t* dilationA,
int64_t* w_dim_padded,
cudnnDataType_t dataType,
std::string fusion_string) {
for(int i=0;i<4;i++) {
fusion_string += 'X';
fusion_string += std::to_string(x_dim_padded[i]);
}
for(int i=0;i<4;i++) {
fusion_string += 'W';
fusion_string += std::to_string(w_dim_padded[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'P';
fusion_string += std::to_string(padA[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'S';
fusion_string += std::to_string(convstrideA[i]);
}
for(int i=0;i<2;i++) {
fusion_string += 'D';
fusion_string += std::to_string(dilationA[i]);
}
fusion_string += 'T';
fusion_string += std::to_string(dataType);
return fusion_string;
}
cudnn_frontend::ExecutionPlan& getOrCreatePlan(cudnnHandle_t handle_,
std::stringstream& log_buf,
cudnn_frontend::OperationGraph& opGraph,
std::string cache_string,
bool use_heuristic = true){
auto it = plan_cache.find(cache_string);
if (it != plan_cache.end()) {
DEBUG_CUDNN_MSG(log_buf, "Found plan in cache");
return it->second;
} else {
if (use_heuristic){
// TODO: confirm which mode to use
auto heuristics = cudnn_frontend::EngineHeuristicsBuilder()
.setOperationGraph(opGraph)
.setHeurMode(CUDNN_HEUR_MODE_INSTANT)
.build();
// try 3 times for now as WAR for no heuristic training
int max_tries = 3, count = 0;
auto& engine_configs = heuristics.getEngineConfig(max_tries);
while(true) {
try {
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder()
.setHandle(handle_)
.setEngineConfig(engine_configs[count], opGraph.getTag())
.build()));
break;
} catch (cudnn_frontend::cudnnException e) {
if (++count == max_tries) throw e;
}
}
}else{
DEBUG_CUDNN_MSG(log_buf, "No plan in cache");
// How many engines support this operation graph ?
auto total_engines = opGraph.getEngineCount();
DEBUG_CUDNN_MSG(log_buf, opGraph.describe() << " has " << total_engines << " engines.");
// We have to randomly pick one engine from [0, total_engines)
// Selecting "0" by default
auto engine = cudnn_frontend::EngineBuilder().setGlobalEngineIdx(0).setOperationGraph(opGraph).build();
DEBUG_CUDNN_MSG(log_buf, engine.describe());
auto& knobs = engine.getSupportedKnobs();
for (auto it = std::begin(knobs); it != std::end(knobs); ++it) {
DEBUG_CUDNN_MSG(log_buf, it->describe());
}
if (knobs.begin() != knobs.end()) {
DEBUG_CUDNN_MSG(log_buf, "Updated knob choice");
knobs.begin()->setChoice(knobs.begin()->getMinValue() + 1);
DEBUG_CUDNN_MSG(log_buf, knobs.begin()->describe());
}
// Createmplacee the requisite engine config
auto engine_config = cudnn_frontend::EngineConfigBuilder().setEngine(engine).build();
DEBUG_CUDNN_MSG(log_buf, engine_config.describe());
plan_cache.emplace(cache_string, std::move(cudnn_frontend::ExecutionPlanBuilder().setHandle(handle_).setEngineConfig(engine_config).build()));
}
return plan_cache.find(cache_string)->second;
}
}
void
run_conv_bias(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* convstride,
int64_t* dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrB,
at::Half* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, y_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterConvTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('c')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto bTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, conv_pad)
.setPostPadding(convDim, conv_pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(xTensor)
.setwDesc(wTensor)
.setyDesc(afterConvTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(bTensor)
.setyDesc(afterBiasTensor)
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Operation Graph. In this case it is convolution bias activation
std::array<cudnn_frontend::Operation const*, 2> ops = {&conv_op, &bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(2, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
int64_t uids[] = {'x', 'w', 'b', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(4, data_ptrs)
.setUids(4, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_conv_bias_mask_relu(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* conv_stride,
int64_t* conv_dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrB,
int8_t* devPtrM,
at::Half* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int conv_dim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, y_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto mTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('m')
.setAlignment(16)
.setDataType(CUDNN_DATA_INT8)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterConvTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('c')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto bTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('B')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterMaskTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('M')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterReLUTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(conv_dim)
.setStrides(conv_dim, conv_stride)
.setPrePadding(conv_dim, conv_pad)
.setPostPadding(conv_dim, conv_pad)
.setDilation(conv_dim, conv_dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Define the mask operation
auto maskDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_MUL)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(xTensor)
.setwDesc(wTensor)
.setyDesc(afterConvTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Bias Node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(bTensor)
.setyDesc(afterBiasTensor)
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// create a Mask Node
auto mask_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setbDesc(mTensor)
.setyDesc(afterMaskTensor)
.setpwDesc(maskDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, mask_op.describe());
// Create an Activation Node
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(mask_op.getOutputTensor())
.setyDesc(afterReLUTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution bias activation
std::array<cudnn_frontend::Operation const*, 4> ops = {&conv_op, &bias_op, &mask_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(4, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrM, devPtrY};
int64_t uids[] = {'x', 'w', 'b', 'm', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_conv_bias_relu(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* conv_stride,
int64_t* conv_dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrB,
at::Half* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int conv_dim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, y_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterConvTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('c')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterConvTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto bTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('b')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, bTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterBiasTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('B')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, afterBiasTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto afterReLUTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, afterReLUTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(conv_dim)
.setStrides(conv_dim, conv_stride)
.setPrePadding(conv_dim, conv_pad)
.setPostPadding(conv_dim, conv_pad)
.setDilation(conv_dim, conv_dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the bias operation
auto biasDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_ADD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Define the activation operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_FWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR)
.setxDesc(xTensor)
.setwDesc(wTensor)
.setyDesc(afterConvTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create a Bias Node.
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(conv_op.getOutputTensor())
.setbDesc(bTensor)
.setyDesc(afterBiasTensor)
.setpwDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Activation Node.
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(bias_op.getOutputTensor())
.setyDesc(afterReLUTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create an Operation Graph. In this case it is convolution bias activation
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &bias_op, &act_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(3, ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrB, devPtrY};
int64_t uids[] = {'x', 'w', 'b', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(4, data_ptrs)
.setUids(4, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_drelu_dbias(int64_t* dy_dim,
cudnnDataType_t dataType,
at::Half* devPtrDY,
at::Half* devPtrR,
at::Half* devPtrDR,
float* devPtrDB) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, dy_dim[1], 1, 1};
// Creates the necessary tensor descriptors
int64_t stride[4];
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto dyTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dy_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, dyTensor.describe());
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto rTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dy_dim)
.setStrides(4, stride)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
generateStrides(dy_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inActGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, dy_dim)
.setStrides(4, stride)
.setId('R')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, inActGradTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto biasGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, biasGradTensor.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the bias backward operation
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
.setMathPrecision(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Create an relu backward Node
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(dyTensor)
.setxDesc(rTensor)
.setdxDesc(inActGradTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create bias node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(inActGradTensor)
.setyDesc(biasGradTensor)
.setreductionDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Operation Graph. In this case it is bias only
std::array<cudnn_frontend::Operation const*, 2> ops = {&act_op, &bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
// creating unique dummy values
int64_t pad_dummy[] = {20, 20};
int64_t stride_dummy[] = {20, 20};
int64_t dilation_dummy[] = {20, 20};
auto cache_string = getConvFusionString(dy_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrDY, devPtrR, devPtrDR, devPtrDB};
int64_t uids[] = {'x', 'r', 'R', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(4, data_ptrs)
.setUids(4, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv_drelu_dbias(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* pad,
int64_t* convstride,
int64_t* dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrR,
at::Half* devPtrRg,
float* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
float alpha = 1.0f;
float beta = 0.0f;
int64_t b_dim[] = {1, x_dim[1], 1, 1};
int64_t stride[4];
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto outConvGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, outConvGradTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inConvGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('A')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.setVirtual()
.build();
DEBUG_CUDNN_MSG(log_buf, inConvGradTensor.describe());
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto rTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('r')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, rTensor.describe());
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inReLUGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('R')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, inReLUGradTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto inBiasGradTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, inBiasGradTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(convDim)
.setStrides(convDim, convstride)
.setPrePadding(convDim, pad)
.setPostPadding(convDim, pad)
.setDilation(convDim, dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Define the activation backward operation
auto actDesc = cudnn_frontend::PointWiseDescBuilder()
.setMode(CUDNN_POINTWISE_RELU_BWD)
.setMathPrecision(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, actDesc.describe());
// Define the bias backward operation
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
.setMathPrecision(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Create a convolution Node
auto conv_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR)
.setdyDesc(outConvGradTensor)
.setwDesc(wTensor)
.setdxDesc(inConvGradTensor)
.setcDesc(convDesc)
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create an relu backward Node
auto act_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setdyDesc(inConvGradTensor)
.setxDesc(rTensor)
.setdxDesc(inReLUGradTensor)
.setpwDesc(actDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, act_op.describe());
// Create bias node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(inReLUGradTensor)
.setyDesc(inBiasGradTensor)
.setreductionDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Operation Graph. In this case it is bias only
std::array<cudnn_frontend::Operation const*, 3> ops = {&conv_op, &act_op, &bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, pad, convstride, dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrR, devPtrRg, devPtrY};
int64_t uids[] = {'x', 'w', 'r', 'R', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(5, data_ptrs)
.setUids(5, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dconv(int64_t* x_dim,
int64_t* w_dim,
int64_t* y_dim,
int64_t* conv_pad,
int64_t* conv_stride,
int64_t* conv_dilation,
cudnnDataType_t dataType,
at::Half* devPtrX,
at::Half* devPtrW,
at::Half* devPtrY,
cudnnBackendDescriptorType_t mode) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int conv_dim = 2;
float alpha = 1.0f;
float beta = 0.0f;
// Define the convolution problem
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(w_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto wTensor = cudnn_frontend::TensorBuilder()
.setDim(4, w_dim)
.setStrides(4, stride)
.setId('w')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, wTensor.describe());
generateStrides(y_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto yTensor = cudnn_frontend::TensorBuilder()
.setDim(4, y_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
// Define the convolution problem
auto convDesc = cudnn_frontend::ConvDescBuilder()
.setDataType(CUDNN_DATA_FLOAT)
.setMathMode(CUDNN_CROSS_CORRELATION)
.setNDims(conv_dim)
.setStrides(conv_dim, conv_stride)
.setPrePadding(conv_dim, conv_pad)
.setPostPadding(conv_dim, conv_pad)
.setDilation(conv_dim, conv_dilation)
.build();
DEBUG_CUDNN_MSG(log_buf, convDesc.describe());
// Create a convolution node
// mode should be one of following
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR
// CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR
auto conv_op_builder = cudnn_frontend::OperationBuilder(mode);
if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
conv_op_builder.setdxDesc(xTensor)
.setwDesc(wTensor)
.setdyDesc(yTensor)
.setcDesc(convDesc);
}
else {
conv_op_builder.setxDesc(xTensor)
.setdwDesc(wTensor)
.setdyDesc(yTensor)
.setcDesc(convDesc);
}
auto conv_op = conv_op_builder
.setAlpha(alpha)
.setBeta(beta)
.build();
DEBUG_CUDNN_MSG(log_buf, conv_op.describe());
// Create an Operation Graph. In this case it is convolution add bias activation
std::array<cudnn_frontend::Operation const*, 1> ops = {&conv_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
auto cache_string = getConvFusionString(x_dim, conv_pad, conv_stride, conv_dilation, w_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrW, devPtrY};
int64_t uids[] = {'x', 'w', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(3, data_ptrs)
.setUids(3, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
void
run_dbias(int64_t* x_dim,
cudnnDataType_t dataType,
at::Half* devPtrX,
float* devPtrY) {
cudnnHandle_t handle_ = torch::native::getCudnnHandle();
std::stringstream log_buf;
try {
int convDim = 2;
int64_t b_dim[] = {1, x_dim[1], 1, 1};
int64_t stride[4];
generateStrides(x_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto xTensor = cudnn_frontend::TensorBuilder()
.setDim(4, x_dim)
.setStrides(4, stride)
.setId('x')
.setAlignment(16)
.setDataType(dataType)
.build();
DEBUG_CUDNN_MSG(log_buf, xTensor.describe());
generateStrides(b_dim, stride, 4, CUDNN_TENSOR_NHWC);
auto yTensor = cudnn_frontend::TensorBuilder()
.setDim(4, b_dim)
.setStrides(4, stride)
.setId('y')
.setAlignment(16)
.setDataType(CUDNN_DATA_FLOAT)
.build();
DEBUG_CUDNN_MSG(log_buf, yTensor.describe());
// Define the bias backward operation
auto biasDesc = cudnn_frontend::ReductionDescBuilder()
.setMathPrecision(CUDNN_DATA_FLOAT)
.setReductionOp(CUDNN_REDUCE_TENSOR_ADD)
.build();
DEBUG_CUDNN_MSG(log_buf, biasDesc.describe());
// Create bias node
auto bias_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR)
.setxDesc(xTensor)
.setyDesc(yTensor)
.setreductionDesc(biasDesc)
.build();
DEBUG_CUDNN_MSG(log_buf, bias_op.describe());
// Create an Operation Graph. In this case it is bias only
std::array<cudnn_frontend::Operation const*, 1> ops = {&bias_op};
auto opGraph = cudnn_frontend::OperationGraphBuilder()
.setHandle(handle_)
.setOperationGraph(ops.size(), ops.data())
.build();
// Create string encoding for plan caching
int64_t pad_dummy[] = {10, 10};
int64_t stride_dummy[] = {10, 10};
int64_t dilation_dummy[] = {10, 10};
auto cache_string = getConvFusionString(x_dim, pad_dummy, stride_dummy, dilation_dummy, b_dim, dataType, opGraph.getTag());
DEBUG_CUDNN_MSG(log_buf, "[convstring] " << cache_string);
auto& plan = getOrCreatePlan(handle_, log_buf, opGraph, cache_string);
DEBUG_CUDNN_MSG(log_buf, "Plan tag: " << plan.getTag());
auto workspace_size = plan.getWorkspaceSize();
DEBUG_CUDNN_MSG(log_buf, plan.describe() << " requires workspace " << workspace_size);
void* workspace_ptr = nullptr;
auto workspace_tensor = at::empty({(workspace_size+3)/4}, at::TensorOptions(at::kCUDA).dtype(at::kFloat));
if (workspace_size > 0) {
workspace_ptr = workspace_tensor.data_ptr<float>();
}
void* data_ptrs[] = {devPtrX, devPtrY};
int64_t uids[] = {'x', 'y'};
auto variantPack = cudnn_frontend::VariantPackBuilder()
.setWorkspacePointer(workspace_ptr)
.setDataPointers(2, data_ptrs)
.setUids(2, uids)
.build();
DEBUG_CUDNN_MSG(log_buf, "variantPack " << variantPack.describe());
cudnnStatus_t status = cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc());
checkCudnnErr(status);
cudnn_frontend::throw_if([status]() { return (status != CUDNN_STATUS_SUCCESS); }, "Plan execute error", status);
} catch (cudnn_frontend::cudnnException e) {
std::cout << log_buf.str() << "[ERROR] Exception " << e.what() << std::endl;
}
}
std::vector<at::Tensor> conv_bias_mask_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
}
// output dim in n,c,h,w used by backend
int64_t y_dim[] = {0, 0, 0, 0};
// use these fixed values
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// compute output from pad/stride/dilation
y_dim[0] = x_dim[0];
y_dim[1] = w_dim[0];
for (int dim = 0; dim < 2; dim++) {
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* b = inputs[2].data_ptr<at::Half>();
int8_t* m = inputs[3].data_ptr<int8_t>();
auto out = at::empty(y_dim, inputs[0].type(), output_format);
at::Half* y = out.data_ptr<at::Half>();
run_conv_bias_mask_relu(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
w,
b,
m,
y);
DEBUG_MSG("[DEBUG] conv-bias-mask-relu : " << y.to(at::kFloat).sum().item<float>());
outputs.push_back(out);
return outputs;
}
std::vector<at::Tensor> conv_bias_relu_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
}
// output dim in n,c,h,w used by backend
int64_t y_dim[] = {0, 0, 0, 0};
// use these fixed values
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// compute output from pad/stride/dilation
y_dim[0] = x_dim[0];
y_dim[1] = w_dim[0];
for (int dim = 0; dim < 2; dim++) {
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* b = inputs[2].data_ptr<at::Half>();
auto out = at::empty(y_dim, inputs[0].type(), output_format);
at::Half* y = out.data_ptr<at::Half>();
run_conv_bias_relu(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
w,
b,
y);
DEBUG_MSG("[DEBUG] conv-bias-relu : " << y.to(at::kFloat).sum().item<float>());
outputs.push_back(out);
return outputs;
}
std::vector<at::Tensor> conv_bias_relu_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
bool requires_grad = inputs[0].requires_grad();
for (int i = 0; i <= 3; i++) {
CHECK_INPUT(inputs[i]);
}
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
int64_t y_dim[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
y_dim[dim] = inputs[3].size(axis[dim]);
}
int64_t b_dim[] = {1, y_dim[1], 1, 1};
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// run
// drelu-dbias
at::Half* dy = inputs[3].data_ptr<at::Half>();
at::Half* r = inputs[2].data_ptr<at::Half>();
auto drelu = at::empty_like(inputs[2]);
at::Half* dr = drelu.data_ptr<at::Half>();
auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
auto bgrad = at::empty(b_dim, options, output_format);
float* db = bgrad.data_ptr<float>();
run_drelu_dbias(y_dim,
CUDNN_DATA_HALF,
dy,
r,
dr,
db);
// conv wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
auto wgrad = at::empty_like(inputs[1]);
at::Half* dw = wgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
dw,
dr,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// conv dgrad
at::Half* w = inputs[1].data_ptr<at::Half>();
auto dgrad = at::empty_like(inputs[0]);
at::Half* dx = dgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
dx,
w,
dr,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
outputs.push_back(dgrad);
outputs.push_back(wgrad);
outputs.push_back(bgrad);
return outputs;
}
std::vector<at::Tensor> conv_bias_forward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
}
// output dim in n,c,h,w used by backend
int64_t y_dim[] = {0, 0, 0, 0};
// use these fixed values
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// compute output from pad/stride/dilation
y_dim[0] = x_dim[0];
y_dim[1] = w_dim[0];
for (int dim = 0; dim < 2; dim++) {
y_dim[dim + 2] = getFwdConvOutputDim(x_dim[dim + 2], conv_pad[dim], w_dim[dim + 2], conv_stride[dim], conv_dilation[dim]);
}
// run
at::Half* x = inputs[0].data_ptr<at::Half>();
at::Half* w = inputs[1].data_ptr<at::Half>();
at::Half* b = inputs[2].data_ptr<at::Half>();
auto out = at::empty(y_dim, inputs[0].type(), output_format);
at::Half* y = out.data_ptr<at::Half>();
run_conv_bias(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
w,
b,
y);
DEBUG_MSG("[DEBUG] conv-bias : " << y.to(at::kFloat).sum().item<float>());
outputs.push_back(out);
return outputs;
}
std::vector<at::Tensor> conv_bias_backward(std::vector<at::Tensor> inputs, int64_t padding, int64_t stride) {
bool requires_grad = inputs[0].requires_grad();
for (int i = 0; i <= 2; i++) {
CHECK_INPUT(inputs[i]);
}
std::cout << std::fixed;
// create output vector
std::vector<at::Tensor> outputs;
auto output_format = at::MemoryFormat::ChannelsLast;
// setup dimensions
int64_t x_dim[] = {0, 0, 0, 0};
int64_t w_dim[] = {0, 0, 0, 0};
int64_t y_dim[] = {0, 0, 0, 0};
// All dim calculation after this order of n,c,h,w
int axis[] = {0, 1, 2, 3};
for (int dim = 0; dim < 4; dim++) {
x_dim[dim] = inputs[0].size(axis[dim]);
w_dim[dim] = inputs[1].size(axis[dim]);
y_dim[dim] = inputs[2].size(axis[dim]);
}
int64_t b_dim[] = {1, y_dim[1], 1, 1};
int64_t conv_pad[] = {padding, padding};
int64_t conv_stride[] = {stride, stride};
int64_t conv_dilation[] = {1, 1};
// run
// dbias
at::Half* dy = inputs[2].data_ptr<at::Half>();
auto options = at::TensorOptions().dtype(at::kFloat).layout(inputs[0].layout()).device(inputs[0].device()).requires_grad(false);
auto bgrad = at::empty(b_dim, options, output_format);
float* db = bgrad.data_ptr<float>();
run_dbias(y_dim,
CUDNN_DATA_HALF,
dy,
db);
// conv wgrad
at::Half* x = inputs[0].data_ptr<at::Half>();
auto wgrad = at::empty_like(inputs[1]);
at::Half* dw = wgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
x,
dw,
dy,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR);
// conv dgrad
at::Half* w = inputs[1].data_ptr<at::Half>();
auto dgrad = at::empty_like(inputs[0]);
at::Half* dx = dgrad.data_ptr<at::Half>();
run_dconv(x_dim,
w_dim,
y_dim,
conv_pad,
conv_stride,
conv_dilation,
CUDNN_DATA_HALF,
dx,
w,
dy,
CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR);
outputs.push_back(dgrad);
outputs.push_back(wgrad);
outputs.push_back(bgrad);
return outputs;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &conv_bias_relu_forward, "Fused Conv-Bias-ReLU forward");
m.def("backward", &conv_bias_relu_backward, "Fused Conv-Bias-ReLU backward");
m.def("forward_no_relu", &conv_bias_forward, "Fused Conv-Bias forward");
m.def("backward_no_relu", &conv_bias_backward, "Fused Conv-Bias backward");
m.def("forward_mask", &conv_bias_mask_relu_forward, "Fused Conv-Bias-Mask-ReLU forward");
}
Subproject commit fa611998a360cbabaa2dcc7c9859748144114fc0
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "fmha.h"
void set_params(Fused_multihead_attention_fprop_params &params,
// sizes
const size_t b,
const size_t s,
const size_t h,
const size_t d,
// device pointers
void *qkv_packed_d,
void *cu_seqlens_d,
void *o_packed_d,
void *s_d,
float p_dropout) {
Data_type acc_type = DATA_TYPE_FP32;
Data_type data_type = DATA_TYPE_FP16;
// Reset the parameters
memset(&params, 0, sizeof(params));
// Set the pointers and strides.
params.qkv_ptr = qkv_packed_d;
params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
params.o_ptr = o_packed_d;
params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);
params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
// S = softmax(P)
params.s_ptr = s_d;
params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);
// Set the dimensions.
params.b = b;
params.h = h;
params.s = s;
params.d = d;
// Set the different scale values.
const float scale_bmm1 = 1.f / sqrtf(d);
constexpr float scale_softmax = 1.f;
constexpr float scale_bmm2 = 1.f;
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
set_alpha(params.scale_softmax, scale_softmax, acc_type);
set_alpha(params.scale_bmm2, scale_bmm2, data_type);
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
params.rp_dropout = 1.f / params.p_dropout;
TORCH_CHECK(p_dropout < 1.f);
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}
std::vector<at::Tensor>
mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens, // b+1
const float p_dropout,
const int max_seq_len,
const bool is_training,
const bool is_nl,
const bool zero_tensors,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
auto stream = at::cuda::getCurrentCUDAStream().stream();
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);
int seq_len = 512;
auto launch = &run_fmha_fp16_512_64_sm80;
if( max_seq_len <= 128 ) {
seq_len = 128;
launch = &run_fmha_fp16_128_64_sm80;
} else if( max_seq_len <= 256 ) {
seq_len = 256;
launch = &run_fmha_fp16_256_64_sm80;
} else if( max_seq_len <= 384 ) {
seq_len = 384;
launch = &run_fmha_fp16_384_64_sm80;
} else if( max_seq_len <= 512 ) {
seq_len = 512;
launch = &run_fmha_fp16_512_64_sm80;
} else {
TORCH_CHECK(false);
}
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto opts = qkv.options();
auto ctx = torch::empty({ total, num_heads, head_size }, opts);
auto s = torch::empty({ batch_size, num_heads, seq_len, seq_len }, opts);
if( zero_tensors ) {
ctx.zero_();
s.zero_();
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
set_params(launch_params.params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
ctx.data_ptr(),
s.data_ptr(),
p_dropout);
launch(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_training ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}
launch(launch_params, /*configure=*/ false);
return { ctx, s };
}
std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len, // max sequence length to choose the kernel
const bool zero_tensors
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
int seq_len = 512;
auto launch = &run_fmha_dgrad_fp16_512_64_sm80;
if( max_seq_len <= 128 ) {
seq_len = 128;
launch = &run_fmha_dgrad_fp16_128_64_sm80;
} else if( max_seq_len <= 256 ) {
seq_len = 256;
launch = &run_fmha_dgrad_fp16_256_64_sm80;
} else if( max_seq_len <= 384 ) {
seq_len = 384;
launch = &run_fmha_dgrad_fp16_384_64_sm80;
} else if( max_seq_len <= 512 ) {
seq_len = 512;
launch = &run_fmha_dgrad_fp16_512_64_sm80;
} else {
TORCH_CHECK(false);
}
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(dout.dtype() == torch::kFloat16);
TORCH_CHECK(softmax.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());
TORCH_CHECK(qkv.is_contiguous());
TORCH_CHECK(cu_seqlens.is_contiguous());
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
auto dqkv = torch::empty_like(qkv);
if( zero_tensors ) {
dqkv.zero_();
}
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
dout.data_ptr(), // we set o_ptr to dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
// we're re-using these scales
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
params.dqkv_ptr = dqkv.data_ptr();
launch(params, stream);
return { dqkv, softmax };
}
std::vector<at::Tensor> mha_bwd_nl(const at::Tensor &dout, // total x num_heads, x head_size
const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
at::Tensor &softmax, // b x h x s x s softmax and dmask - will be overwritten with dP
const at::Tensor &cu_seqlens, // b+1
const float p_dropout, // probability to drop
const int max_seq_len, // max sequence length to choose the kernel
const bool zero_tensors
) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK(qkv.is_cuda())
TORCH_CHECK(cu_seqlens.is_cuda())
TORCH_CHECK(qkv.is_contiguous())
TORCH_CHECK(cu_seqlens.is_contiguous())
TORCH_CHECK(cu_seqlens.dim() == 1);
TORCH_CHECK(qkv.dim() == 4);
const auto sizes = qkv.sizes();
TORCH_CHECK(sizes[THREE_DIM] == 3);
const int batch_size = cu_seqlens.numel() - 1;
const int total = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
TORCH_CHECK(batch_size > 0);
TORCH_CHECK(head_size == 64);
int seq_len = 512;
auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;
auto opts = qkv.options();
auto dqkv = torch::empty_like(qkv);
if( zero_tensors ) {
dqkv.zero_();
}
int num_chunks = 2;
if( batch_size == 1 ) {
num_chunks = 4;
}else if( batch_size == 2 ) {
num_chunks = 3;
}
auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);
Fused_multihead_attention_fprop_params params;
set_params(params,
batch_size,
seq_len,
num_heads,
head_size,
qkv.data_ptr(),
cu_seqlens.data_ptr(),
dout.data_ptr(), // o_ptr = dout
softmax.data_ptr(), // softmax gets overwritten by dP!
p_dropout);
params.dkv_ptr = dkv.data_ptr();
Data_type acc_type = DATA_TYPE_FP32;
set_alpha(params.scale_bmm1, 1.f, acc_type);
set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
params.dqkv_ptr = dqkv.data_ptr();
launch(params, num_chunks, stream);
//SPLIT-K reduction of num_chunks dK, dV parts
// The equivalent of the following Pytorch code:
// using namespace torch::indexing;
// at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
// torch::sum_out(view_out, dkv, 1);
const int hidden_size = num_heads * head_size;
fmha_run_noloop_reduce(
dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total, num_chunks, stream);
return { dqkv, softmax, dkv };
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "Fused Multi-head Self-attention for BERT";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("bwd", &mha_bwd, "Backward pass");
m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)");
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <fmha_utils.h>
constexpr int TOTAL_DIM = 0;
constexpr int THREE_DIM = 1;
constexpr int H_DIM = 2;
constexpr int D_DIM = 3;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
// The QKV matrices.
void * __restrict__ qkv_ptr;
// The stride between rows of the Q, K and V matrices.
size_t qkv_stride_in_bytes;
// The number of heads.
int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fused_multihead_attention_fprop_params : public Qkv_params {
// The dQKV matrices.
void * __restrict__ dqkv_ptr;
// Temporary for dKV.
void * __restrict__ dkv_ptr;
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
int64_t o_stride_in_bytes;
// The pointer to the S matrix, overwritten by the dP matrix (bwd).
void * __restrict__ s_ptr;
// The stride between rows of the S matrix.
int64_t s_stride_in_bytes;
// The dimensions.
int b, s, d;
// The scaling factors for the kernel.
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens;
// The dropout probability (probability of keeping an activation).
float p_dropout;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
// Scale factor of 1 / (1 - p_dropout), in half2.
uint32_t scale_dropout;
// Random state.
at::PhiloxCudaState philox_args;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
struct Launch_params{
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_training_,
bool is_nl_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_training(is_training_)
, is_nl(is_nl_) {
}
size_t elts_per_thread;
cudaDeviceProp * props;
cudaStream_t stream;
bool is_training;
Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
bool is_nl;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_384_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_fp16_512_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure);
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const bool is_training, const int num_chunks, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream);
void fmha_run_noloop_reduce(void *out,
const void *in,
const int *cu_seqlens,
const int hidden_size,
const int batch_size,
const int total,
const int num_chunks,
cudaStream_t stream);
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n))
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {
// The data type.
using Data_type = Data_type_;
// default input type
using Input_type_ = Data_type_;
// Does it store the array of elements.
enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };
// The number of elements.
enum { NUM_ELTS = NUM_ELTS_ };
// The size of element in bits.
enum { BITS_PER_ELT = BITS_PER_ELT_ };
// The size of byte of a single register.
enum { BYTES_PER_REG = 4 };
// The size in bits.
enum { BITS_PER_REG = BYTES_PER_REG * 8 };
// The number of registers needed to store the fragment.
enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };
// The size in bytes (as returned by sizeof(Fragment_base<>).
enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };
// The alignment.
enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The type of the elements.
typename Data_type_,
// The number of elements.
int NUM_ELTS_,
// The alignment if you want to force a value -- use 0 otherwise.
int ALIGNMENT_ = 0,
// The base class.
typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
// The size of a load/store.
enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };
// Clear the fragment. Using PTX in that code seems to produce better SASS...
inline __device__ void clear() {
#pragma unroll
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
}
}
// Immutable access to a register.
inline __device__ const uint32_t& reg(int ii) const {
return this->regs_[ii];
}
// Mutable access to a register.
inline __device__ uint32_t& reg(int ii) {
return this->regs_[ii];
}
uint32_t regs_[Base_::NUM_REGS];
// Immutable access to the elements.
inline __device__ const Data_type_& elt(int ii) const {
return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
}
// Mutable access to the elements.
inline __device__ Data_type_& elt(int ii) {
return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
}
// Immutable access to the elements with a cast.
template< typename Cast_type >
inline __device__ const Cast_type& elt_as(int ii) const {
return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
}
// Mutable access to the elements.
template< typename Cast_type >
inline __device__ Cast_type& elt_as(int ii) {
return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
}
// Add another fragment.
inline __device__ void add(const Fragment &other) {
#pragma unroll
for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
this->elt(ii) += other.elt(ii);
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Fragment_accumulator : public Fragment<float, 8> {
// The base class.
using Base = Fragment<float, 8>;
// Add two fragments.
template< typename Other_fragment_ >
inline __device__ void add(const Other_fragment_ &other) {
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
this->elt(ii) = this->elt(ii) + other.elt(ii);
}
}
// Do the HMMA.
template< typename Layout_a, typename Layout_b >
inline __device__ void mma(const Fragment_a<Layout_a> &a,
const Fragment_b<Layout_b> &b) {
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(0)), "r"(b.reg(1)));
asm volatile( \
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
" {%0, %1, %2, %3}, \n" \
" {%4, %5, %6, %7}, \n" \
" {%8, %9}, \n" \
" {%0, %1, %2, %3}; \n" \
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
, "r"(b.reg(2)), "r"(b.reg(3)));
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
frag[mi][ni].clear();
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
template< typename Acc, int M, int N >
static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
fmha::clear(acc);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Acc, typename A, typename B, int M, int N>
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < N; ++ni ) {
acc[mi][ni].mma(a[mi], b[ni]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The number of rows in the CTA tile.
int M_,
// The number of cols in the CTA tile.
int N_,
// The number of elements in the the K dimension of the GEMM loop.
int K_,
// The number of rows of warps.
int WARPS_M_,
// The number of cols of warps.
int WARPS_N_,
// The number of warps in the K dimension of the GEMM loop.
int WARPS_K_>
struct Cta_tile_ {
enum { M = M_, N = N_, K = K_ };
// The number of warps.
enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };
// The number of warps per CTA.
enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
// The number of threads per warp.
enum { THREADS_PER_WARP = 32 };
// The number of threads per CTA.
enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Hmma_tile {
// The number of elements computed with a single warp-MMA.
enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };
// The number of elements computed with a single CTA-MMA.
enum {
M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
};
// The number of MMAs needed to compute the GEMM.
enum {
MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,
MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,
MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,
};
// The number of elements computed per warp.
enum {
M_PER_WARP = MMAS_M * M_PER_MMA,
N_PER_WARP = MMAS_N * N_PER_MMA,
K_PER_WARP = MMAS_K * K_PER_MMA,
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile_>
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
Cta_tile_::N,
Next_power_of_two<Cta_tile_::K>::VALUE,
Cta_tile_::WARPS_M,
Cta_tile_::WARPS_N,
Cta_tile_::WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The number of bits per element.
int BITS_PER_ELEMENT,
// The number of rows of Q, K or V loaded by this tile.
int ROWS,
// The number of columns.
int COLS,
// The number of matrics.
int NUM_MATS = 3
>
struct Gmem_tile_qkv {
// The size of each LDG.
enum { BYTES_PER_LDG = 16 };
// The size of a row in bytes.
enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };
// The number of threads to load a "row" of the matrix.
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };
// The number of "rows" loaded per LDG.
enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// The number of LDGs needed to load a chunk of the Q matrix.
enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };
// Ctor.
template< typename Params, typename BInfo >
inline __device__ Gmem_tile_qkv(const Params &params, const int qkv_offset, const BInfo &binfo, const int tidx)
: params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
, actual_seqlen(binfo.actual_seqlen)
, qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable the loads.
row_ = row;
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
// Add the block index.
row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
// Assemble the final pointer.
qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
}
// Store data to shared memory.
template< typename Smem_tile >
inline __device__ void commit(Smem_tile &smem_tile) {
smem_tile.store(fetch_);
}
// Load data from memory.
template< typename Smem_tile >
inline __device__ void load(Smem_tile &smem_tile) {
const void *ptrs[LDGS];
uint32_t preds[LDGS];
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
fetch_[ii] = make_uint4(0, 0, 0, 0);
}
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
fct.load(ii, preds[ii]);
}
}
// Store data to memory.
inline __device__ void store(const uint4 (&data)[LDGS]) {
#pragma unroll
for( int ii = 0; ii < LDGS; ++ii ) {
char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
fmha::stg(ptr, data[ii]);
}
}
}
// Move the pointer to the next location.
inline __device__ void move() {
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
actual_seqlen -= ROWS;
}
inline __device__ void move(int steps) {
qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_ * steps;
actual_seqlen -= ROWS * steps;
}
// The stride between rows for the QKV matrice.
int64_t params_qkv_stride_in_bytes_;
// The pointer.
char *qkv_ptr_;
// The fetch registers.
uint4 fetch_[LDGS];
// Keep track of the row the thread is processing as we move the tile.
int row_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile >
struct Gmem_tile_o {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The size of each element.
enum { BYTES_PER_ELEMENT = 2 };
// The size of a row in bytes.
enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };
// The number of threads to store a "row" of the matrix.
enum { THREADS_PER_ROW = 16 };
// The size of each STG.
enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW };
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
enum { ROWS = Cta_tile::M };
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
// The number of outter loop for the stores.
enum { LOOPS = ROWS / ROWS_PER_LOOP };
// The number of "rows" stored per STG.
enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Do we have to guard against partial writes/reads.
enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };
// The number of STGs needed to store a chunk of the Q matrix.
enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };
// The number of STGs needed to store a chunk of the Q matrix in total.
enum { STGS = STGS_PER_LOOP * LOOPS };
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, int tidx)
: params_o_stride_in_bytes_(params.o_stride_in_bytes)
, actual_seqlen_(binfo.actual_seqlen)
, o_ptr_(reinterpret_cast<char *>(params.o_ptr)) {
// Compute the position in the sequence (within the CTA for the moment).
int row = tidx / THREADS_PER_ROW;
// Compute the position of the thread in the row.
int col = tidx % THREADS_PER_ROW;
// Store the row as we need it to disable loads.
row_ = row;
// The row offset in the batched GEMM.
int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
// Assemble the final pointer.
o_ptr_ += row_offset + col * BYTES_PER_STG;
// Is that thread active on the last STG?
if( HAS_INCOMPLETE_STG ) {
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
}
}
// Store data to global memory.
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
#pragma unroll
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
int jj = mi * STGS_PER_LOOP + ii;
if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
break;
}
float x = reinterpret_cast<const float &>(src[ii].x);
float y = reinterpret_cast<const float &>(src[ii].y);
float z = reinterpret_cast<const float &>(src[ii].z);
float w = reinterpret_cast<const float &>(src[ii].w);
uint2 out = float4_to_half4(x, y, z, w);
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);
}
}
}
// Move the pointer to the next location.
inline __device__ void move() {
row_ += ROWS;
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
}
inline __device__ void move(const int steps) {
row_ += ROWS * steps;
o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_ * steps;
}
// The stride between rows for the QKV matrice.
int64_t params_o_stride_in_bytes_;
// The pointer.
char *o_ptr_;
// Is the thread active for the last STG?
int is_active_for_last_stg_;
// Keep track of the row to disable loads.
int row_;
// The length of the sequence loaded by that memory tile.
int actual_seqlen_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// Each STG stores 8 elements.
enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };
// The number of MMAs in the M dimension.
enum { MMAS_M = Mma_tile::MMAS_M };
// The number of MMAs in the N dimension.
enum { MMAS_N = Mma_tile::MMAS_N };
// The number of rows computed per MMA per thread block.
enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };
// The number of cols computed per MMA per thread block.
enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA };
// The number of threads per block.
enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA };
// The size of each row in bytes. I.e. how many bytes are stored per STG.
enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };
// The fixed sequence length.
enum { SEQLEN = Cta_tile::N };
// The distance between two blocks (in bytes).
enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };
// The distance between elements stored per loop (in bytes).
enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };
// The type of elements stored per STG.
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index.
size_t bidx = bidb * params.h + bidh;
// Set store location for each thread at the beginning of the loop
ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;
}
// Store to global memory.
inline __device__ void store(const Type &data, const int mi, const int ni) {
size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::stg(ptr_ + offset, data);
}
// Load from global memory.
inline __device__ void load(Type &data, const int mi, const int ni) {
size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::ldg(data, ptr_ + offset);
}
// Move to the next tile.
inline __device__ void move() {
ptr_ += LOOP_STRIDE_BYTES;
}
inline __device__ void move(const int steps) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory.
char *ptr_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {
// The number of mmas in the vertical dimension.
enum { M = Base::MMAS_M };
// The number of mmas in the horizontal dimension.
enum { N = Base::MMAS_N };
// The type of the vectors stored by each STG.
using Type = typename Base::Type;
// Ctor.
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
template<typename Mask>
inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
float tmp03 = softmax[2 * mi + 0][4 * ni + 3];
float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
float tmp13 = softmax[2 * mi + 1][4 * ni + 3];
uint4 dst;
dst.x = fmha::float2_to_half2(tmp00, tmp01);
dst.y = fmha::float2_to_half2(tmp02, tmp03);
dst.z = fmha::float2_to_half2(tmp10, tmp11);
dst.w = fmha::float2_to_half2(tmp12, tmp13);
if( mask.is_valid(mi, ni, 0, 0) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].reg(0);
dst.y = frag[ni][mi].reg(2);
dst.z = frag[ni][mi].reg(1);
dst.w = frag[ni][mi].reg(3);
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The base class.
typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
>
struct Gmem_tile_dout : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dout(const Params &params, const BInfo &binfo, int tidx)
: Base(params, 0, binfo, tidx) {
this->qkv_ptr_ = reinterpret_cast<char *>(params.o_ptr);
this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes; // needed for move
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >
struct Gmem_tile_dq : public Base {
// Ctor.
template<typename Params, typename BInfo>
inline __device__ Gmem_tile_dq(const Params &params, const BInfo &binfo, int tidx)
: Base(params, binfo, tidx) {
this->o_ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);
this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes; // needed for move
// Compute the position of the thread in the row.
int col = tidx % Base::THREADS_PER_ROW;
// The row offset in the batched GEMM. For each seq element, we store O in that order.
int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
(binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;
// Assemble the final pointer.
this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V.
enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u };
// Do we keep K in registers.
enum { K_IN_REGS = (FLAGS & 0x10u) == 0u };
// The global memory tile to load Q.
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
// The shared memory tile to swizzle Q.
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
// The global memory tile to load K.
using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle K.
using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;
// The global memory tile to load V.
using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
// The shared memory tile to swizzle V.
using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;
// The global memory tile to store O.
using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
// The shared memory tile for O.
using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;
// The global memory tile to load/store S.
using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
// The shared memory tile to transpose S.
using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;
// Make sure the number of threads match.
static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
// The number of threads.
enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
// Make sure the number of threads matches both CTAs.
static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, "");
// The amount of shared memory needed to load Q and K.
enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
// The extra amount of shared memory needed to load V.
enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
// The amount of shared memory needed for Q, K and V..
enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
// The amount of shared memory needed to load Q and store O.
enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };
// The amount of shared memory needed for Q, K, V and O.
enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };
// Make sure we have enough shared memory.
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};
////////////////////////////////////////////////////////////////////////////////////////////////////
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
template<typename Cta_tile>
struct Mask {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
template<typename Params, typename BInfo>
__device__ Mask(const Params &params, const BInfo &blockInfo, int tidx) {
actual_seqlen = blockInfo.actual_seqlen;
const int warp = tidx / Cta_tile::THREADS_PER_WARP;
const int lane = tidx % Cta_tile::THREADS_PER_WARP;
static_assert(Cta_tile::WARPS_K == 1, "");
// find the warp in the Cta tile
const int warp_n = (warp / Cta_tile::WARPS_M);
const int warp_m = (warp % Cta_tile::WARPS_M);
// decompose warp into 8x4 tile
const int quad = lane / 4;
const int tid = (lane % 4) * 2;
row = warp_m * 16 + quad;
col = warp_n * 16 + tid;
}
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
// ii and jj iterate over the 2x4 fragment
const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
return col_valid;
// return row_valid && col_valid;
}
//BERT Mask: if upper left is invalid, none are valid
inline __device__ bool any_valid(int mi, int ni) const {
return is_valid(mi, ni, 0, 0);
}
inline __device__ void load(int it) {
row_offset = it * Cta_tile::M + row;
}
int row_offset;
int row;
int col;
int actual_seqlen;
};
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The description of the tile computed by this CTA.
typename Cta_tile,
// The number of rows in the 2D shared memory buffer.
int M_,
// The number of cols.
int N_,
// The size in bits of each element.
int BITS_PER_ELEMENT_,
// The number of bytes per STS.
int BYTES_PER_STS_ = 16,
// The number of buffers. (Used in multistage and double buffer cases.)
int BUFFERS_PER_TILE_ = 1,
// Do we enable the fast path for LDS.128 and friends.
int ENABLE_LDS_FAST_PATH_ = 0,
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
int ROWS_PER_XOR_PATTERN_ = 8,
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
int COLS_PER_XOR_PATTERN_ = 1,
// Use or not predicates
bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {
// The size in bits of each element.
enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
// The size in bytes of a single STS.
enum { BYTES_PER_STS = BYTES_PER_STS_ };
// The number of elements per STS.
enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
// To support arbitrary N, we pad some values to a power-of-2.
enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
// The number of bytes per row without packing of rows.
enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
// The number of bytes per row -- we want at least 128B per row.
enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
// The number of rows in shared memory (two rows may be packed into a single one).
enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };
// The number of threads per row.
enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
// The number of threads per row.
enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };
// The number of STS per row.
enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
// It must be at least one.
static_assert(STS_PER_ROW >= 1, "");
// The number of rows written with a single STS.
enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
static_assert(ROWS_PER_STS >= 1, "");
// The number of STS needed to store all rows.
enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
// The number of STS in total.
enum { STS = STS_PER_COL * STS_PER_ROW };
// The size of one buffer in bytes in shared memory.
enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
// The number of buffers.
enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
// The size in bytes of total buffers.
enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
// The boundary for smem_read_offset and smem_write_offset increment.
enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };
// Do we enable the LDS.128 fast path?
enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
static_assert(ENABLE_LDS_FAST_PATH == 0);
// The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
// The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
// Use or not predicates
enum { USE_PREDICATES = USE_PREDICATES_ };
// The type of elements that are stored in shared memory by each thread.
using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;
// Ctor.
inline __device__ Smem_tile_without_skews(void *smem, int tidx)
: smem_(__nvvm_get_smem_pointer(smem)) {
// The row written by a thread. See doc/mma_smem_layout.xlsx.
int smem_write_row = tidx / THREADS_PER_ROW;
// The XOR pattern.
int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
// Compute the column and apply the XOR pattern.
int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;
// The offset.
this->smem_write_offset_ = smem_write_row*BYTES_PER_ROW + smem_write_col*BYTES_PER_STS;
// TODO: Why not merge it with the read offset?
this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
}
// Compute the store pointers.
template< int N >
inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
// Decompose the STS into row/col.
int row = ii / STS_PER_ROW;
int col = ii % STS_PER_ROW;
// Assemble the offset.
int offset = smem_write_offset_ + row*ROWS_PER_STS*BYTES_PER_ROW;
// Take the column into account.
if( STS_PER_ROW > 1 ) {
offset += col*THREADS_PER_ROW*BYTES_PER_STS;
}
// Apply the XOR pattern if needed.
if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
}
// Assemble the final pointer :)
ptrs[ii] = smem_ + offset + smem_write_buffer_;
}
}
inline __device__ void debug_reset() {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val = 0x0;
sts(val, smem_ + row*BYTES_PER_ROW + col + buffer);
}
}
}
}
}
// Print the content of the tile (only for debug ;)).
inline __device__ void debug_print() const {
for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER) {
for( int row = 0; row < ROWS; ++row ) {
for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
if( threadIdx.x == 0 ) {
uint32_t val;
lds(val, smem_ + row*BYTES_PER_ROW + col + buffer);
printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
blockIdx.x,
blockIdx.y,
blockIdx.z,
smem_,
buffer,
row,
col,
val);
}
}
}
}
}
// Move the read offset to next buffer.
inline __device__ void move_to_next_read_buffer() {
if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_buffer_ += BYTES_PER_BUFFER;
}
}
// Move the read offset to next buffer. TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer() {
this->move_to_next_read_buffer();
}
// Move the read offset to next N buffer (circular-buffer).
inline __device__ void move_to_next_read_buffer(int N) {
if( BUFFERS_PER_TILE > 1 ) {
this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
}
}
// Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
inline __device__ void move_next_read_buffer(int N) {
this->move_to_next_read_buffer(N);
}
// Move the write offset to next buffer.
inline __device__ void move_to_next_write_buffer() {
if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
} else if( BUFFERS_PER_TILE > 1 ) {
this->smem_write_buffer_ += BYTES_PER_BUFFER;
}
}
// Move the write offset to next buffer. TODO: Remove that member function!
inline __device__ void move_next_write_buffer() {
this->move_to_next_write_buffer();
}
// Move the read offset.
inline __device__ void move_read_offset(int delta) {
this->smem_read_offset_ += delta;
}
// Move the write offset.
inline __device__ void move_write_offset(int delta) {
this->smem_write_offset_ += delta;
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data);
}
// Store to the tile in shared memory.
template< int N, int M >
inline __device__ void store(const Store_type (&data)[N], uint32_t (&preds)[M], uint64_t = 0) {
uint32_t smem_ptrs[N];
this->compute_store_pointers(smem_ptrs);
sts(smem_ptrs, data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const Store_type (&data)[N], uint32_t preds, uint64_t = 0) {
this->store(data, preds);
}
// Store to the tile in shared memory.
template< int N >
inline __device__ void store(const void* (&gmem_ptrs)[N], uint32_t preds, uint64_t = 0) {
uint32_t tmp[1] = { preds };
this->store(gmem_ptrs, tmp);
}
// The shared memory pointer.
uint32_t smem_;
// The read offset. Reserve 4 offsets if needed.
int smem_read_offset_;
// The write offset.
int smem_write_offset_;
// The buffer base offset for read.
int smem_read_buffer_;
// The buffer base offset for write.
int smem_write_buffer_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Use or not predicates
bool USE_PREDICATES = true
>
struct Smem_tile_a {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
// The potential mask.
enum { HALF = MMAS_K_WITH_PADDING / 2 };
// The remainder.
enum { MOD = MMAS_K % HALF };
// The final value.
enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K_WITH_PADDING >
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
enum { VALUE = 0 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int MMAS_K >
struct Compute_reset_mask<MMAS_K, MMAS_K> {
enum { VALUE = MMAS_K - 1 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_a {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::M,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_A,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_a<Row>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
// The row and column read by the thread.
int smem_read_row = (tidx & 0x0f);
int smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
#pragma unroll
for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
// Store the value into the fragment.
a[mi].reg(0) = tmp.x;
a[mi].reg(1) = tmp.y;
a[mi].reg(2) = tmp.z;
a[mi].reg(3) = tmp.w;
}
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_a<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The layout of the tile.
typename Layout,
// The size of the STS.
int BYTES_PER_STS = 16,
// The number of buffers per tile.
int BUFFERS_PER_TILE = 1,
// Use or not predicates
bool USE_PREDICATES = true
>
struct Smem_tile_b {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_b {
// The size in bits.
enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
// The number of rows.
enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::N,
Cta_tile::K,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
1>;
// The fragment.
using Fragment = Fragment_b< Col>;
// When we use padding to reach a power of two, special care has to be taken.
using Cta_tile_with_padding = Cta_tile_with_k_with_padding< Cta_tile>;
// The number of MMAs.
using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {
// For documentation on the layout, see doc/mma_smem_layout.xlsx.
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
static_assert(WARPS_M == 1);
static_assert(WARPS_N == 4 || WARPS_N == 8);
static_assert(WARPS_K == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
// The row and column read by the thread.
int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
(tidx & 0x07) +
(tidx & 0x10) / 2;
int smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & 0x08) / 8;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;
// Load using LDSM.M88.4.
uint4 tmp;
ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);
// Store the value into the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
}
// Move the offset to the next possition. See doc/mma_smem_layout.xlsx.
static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
if( Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 16 && ki % 8 == 7 ) {
this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 8 && ki % 4 == 3 ) {
this->smem_read_offset_ ^= 7 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 4 && ki % 2 == 1 ) {
this->smem_read_offset_ ^= 3 * BYTES_PER_LDS * 2;
} else if( Mma_tile_with_padding::MMAS_K >= 2 ) {
this->smem_read_offset_ ^= 1 * BYTES_PER_LDS * 2;
}
}
// Reset the read offset.
inline __device__ void reset_read_offset() {
// The number of MMAs in the K dimension.
enum { MMAS_K = Mma_tile::MMAS_K };
// The number of MMAs in the K dimension when we include padding.
enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
// Assemble the mask.
enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };
// Reset the read offset.
this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b< Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE >
: public Smem_tile_col_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_col_b< Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b< N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE,
// How many rows to use for the XOR pattern to avoid bank conflicts?
int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
// How many cols to use for the XOR pattern to avoid bank conflicts?
int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_> {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The base class.
using Base = Smem_tile_without_skews<Cta_tile,
Cta_tile::K,
Cta_tile::N,
fmha::BITS_PER_ELEMENT_B,
BYTES_PER_STS,
BUFFERS_PER_TILE,
0,
ROWS_PER_XOR_PATTERN_,
COLS_PER_XOR_PATTERN_>;
// The fragment.
using Fragment = Fragment_b<Row>;
// Can we use LDSM? No if the data type is 32-bit large.
enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
// The number of elements per LDS.
enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };
// The number of STS per thread
enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
// The number of STS per thread must be at least 1.
enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };
// Ctor.
inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {
// The number of warps.
const int WARPS_M = Cta_tile::WARPS_M;
const int WARPS_N = Cta_tile::WARPS_N;
const int WARPS_K = Cta_tile::WARPS_K;
static_assert(WARPS_K == 1);
static_assert(WARPS_M == 4 || WARPS_M == 8);
static_assert(WARPS_N == 1);
// The masks to select the warps.
const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;
// The divisor for the warps.
const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;
const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;
// The row/col read by the thread.
int smem_read_row, smem_read_col;
static_assert(USE_LDSMT);
static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
(tidx & 0x07) + (tidx & 0x08);
smem_read_col = (tidx & 0x07);
smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = smem_read_row*Base::BYTES_PER_ROW + smem_read_col*BYTES_PER_LDS;
// Fill zeroes for group conv
}
// Rewind smem_read_offset for last LDS phase in main loop.
inline __device__ void reverse_smem_read_offset(int ki = 0) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Undo the pointer increment for the next ni.
// Should match the load function below for ki = 0.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
// The size of each element in bits.
const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
// The size in bytes of the data needed to compute an MMA per CTA.
const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Prepare the offset.
int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;
if ( BYTES_PER_MMA_PER_CTA == 32 ) {
offset += this->smem_read_offset_;
} else if ( BYTES_PER_MMA_PER_CTA == 64 ) {
offset += this->smem_read_offset_ + (ni/2) * BYTES_PER_MMA_PER_CTA * 2;
} else {
offset += this->smem_read_offset_ + (ni ) * BYTES_PER_MMA_PER_CTA;
}
// Load the data using LDSM.MT88.2.
uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
uint4 tmp;
if( USE_LDSMT ) {
ldsmt(tmp, ptr);
} else {
lds(tmp.x, (ptr ) + 0*Base::BYTES_PER_ROW);
lds(tmp.y, (ptr ) + 4*Base::BYTES_PER_ROW);
lds(tmp.z, (ptr ^ 32) + 0*Base::BYTES_PER_ROW);
lds(tmp.w, (ptr ^ 32) + 4*Base::BYTES_PER_ROW);
}
// Store those values in the fragment.
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( BYTES_PER_MMA_PER_CTA >= 128 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
} else if( BYTES_PER_MMA_PER_CTA == 64 ) {
// Nothing to do!
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
}
}
// Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 &&
Mma_tile::MMAS_N % 2 == 1 ) {
this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile,
// The size of the STS.
int BYTES_PER_STS,
// The number of buffers per tile.
int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
: public Smem_tile_row_b<Cta_tile,
BYTES_PER_STS,
BUFFERS_PER_TILE> {
// The base class.
using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;
// Ctor.
inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_v : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1> {
// The base class.
using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1>;
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The fragment.
using Fragment = Fragment_b< fmha::Col>;
// The size of a single LDS in bytes.
enum { BYTES_PER_LDS = 16 };
// Ctor.
inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {
// The row/col read by the thread.
int read_row, read_col;
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
read_col = (tidx & 0x07);
read_col ^= (tidx & 0x10) / 16;
// The shared memory offset.
this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
// Load from shared memory.
inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// Jump by 16 * #warps row.
int row = ki * 16 * Cta_tile::WARPS_K;
// Load the data using LDSM.MT88.2.
uint4 tmp;
fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW);
b[ni].reg(0) = tmp.x;
b[ni].reg(1) = tmp.y;
b[ni].reg(2) = tmp.z;
b[ni].reg(3) = tmp.w;
// Move the pointer for the next ni. I expect the compiler to not recompute those.
if( Mma_tile::MMAS_N == 4 ) {
this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
} else {
assert(false); // Not implemented!
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile>
struct Smem_tile_o {
// The MMA tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
// The accumulators.
using Data_type = typename Accumulator::Data_type;
// The size of each element.
enum { BYTES_PER_ELEMENT = sizeof(Data_type) };
// The size of each STS.
enum { BYTES_PER_STS = 8 };
// The size of each row in shared memory.
enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT };
// The size of each LDS.
enum { BYTES_PER_LDS = 16 };
enum { THREADS_PER_ROW = 16 };
// The number of rows.
enum { ROWS = Cta_tile::M };
// The number of "rows" to process per loop iteration (in the "epilogue").
enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
// The number of outer loops.
enum { LOOPS = ROWS / ROWS_PER_LOOP };
// Make sure it matches our expectations.
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// The number of rows loaded per LDS.
enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
// Do we have to guard against partial writes/reads.
enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 };
// The total number of LDS per loop.
enum { LDS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_LDS>::VALUE };
// The amount of shared memory.
enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW };
// The write pointer.
uint32_t smem_write_, smem_read_;
// Is the thread active for the last LDS of the series?
int is_active_for_last_lds_;
static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");
// Ctor.
inline __device__ Smem_tile_o(void *smem, int tidx) {
// Get a 32-bit value for the shared memory address.
uint32_t smem_ = __nvvm_get_smem_pointer(smem);
static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 && (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));
int write_row = (tidx & 0x1c) / 4;
int write_col = (tidx);
// Assemble the write pointer.
smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// The element read by each thread.
int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
// Take the XOR pattern into account for the column.
read_col ^= 2 * (read_row & 0x7);
// Assemble the read pointer.
this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// Is that thread active on the last LDS?
if( HAS_INCOMPLETE_LDS ) {
this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
}
}
// Load the output fragments.
inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
#pragma unroll
for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {
// Load the elements before the reduction (split-K).
uint4 tmp[Cta_tile::WARPS_K];
#pragma unroll
for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
fmha::lds(tmp[jj], this->smem_read_ + imm);
}
}
// Perform the reduction.
out[ii] = tmp[0];
#pragma unroll
for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
out[ii] = fmha::fadd4(out[ii], tmp[jj]);
}
}
}
// Store the accumulators.
template <int M, int N>
inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
#pragma unroll
for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
// The number of MMAs that are stored per loop iteration.
enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };
// Store 1st column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// Swizzle the write pointer using a XOR of 16B.
this->smem_write_ ^= 32;
// Store 2nd column of the different MMAs.
#pragma unroll
for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
// Precompute the immediates to jump between rows.
int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;
uint2 tmp0, tmp1;
tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);
tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);
// Store.
fmha::sts(this->smem_write_ + row_0, tmp0);
fmha::sts(this->smem_write_ + row_1, tmp1);
}
// Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Smem_tile_mma {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
using Fragment = fmha::Fragment_a<fmha::Col>;
enum { COLS = Cta_tile::N };
enum { BYTES_PER_ELT = 2 };
enum { BYTES_PER_STS = 4 };
enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT }; // TODO
enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
enum { WARPS_K = Cta_tile::WARPS_K };
static_assert(WARPS_K == 1);
inline __device__ Smem_tile_mma(char *smem, int tidx) {
smem_ = __nvvm_get_smem_pointer(smem);
int write_col, write_row;
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) || (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
write_row = (tidx & 0x1c) / 4;
write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
} else {
write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
write_col = (tidx & 0x03);
}
write_col ^= (write_row & 0x07) * 4;
write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
static_assert(COLS == Cta_tile::N);
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
uint32_t smem_;
uint32_t write_offset_;
uint32_t warp_m;
uint32_t warp_n;
uint32_t lane;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
using Fragment = typename Base::Fragment;
inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;
read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
template<int M, int N>
inline __device__ void load(Fragment (&frag)[M][N]) {
static_assert(Base::COLS == Cta_tile::N);
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
frag[mi][ni].reg(0) = dst.x;
frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major!
frag[mi][ni].reg(2) = dst.y;
frag[mi][ni].reg(3) = dst.w;
}
}
}
uint32_t read_offset_;
};
template< typename Cta_tile, typename Base = Smem_tile_mma< Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
enum { BYTES_PER_LDS = 16 };
enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
enum { WARPS_M = Base::WARPS_M };
enum { WARPS_N = Base::WARPS_N };
static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
using Acc = fmha::Fragment_accumulator;
inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
const int read_row = tidx / THREADS_PER_ROW;
int read_col = tidx % THREADS_PER_ROW;
read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}
inline __device__ void load(uint4 (&data)[NUM_LDS]) {
for( int ii = 0; ii < NUM_LDS; ii++ ) {
size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
fmha::lds(data[ii], this->smem_ + offset);
}
}
template<int M, int N>
inline __device__ void store(const Acc (&acc)[M][N]){
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// 1st row - 4 elements per row.
float tmp00 = acc[mi][ni].elt(0);
float tmp01 = acc[mi][ni].elt(1);
float tmp02 = acc[mi][ni].elt(4);
float tmp03 = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
float tmp10 = acc[mi][ni].elt(2);
float tmp11 = acc[mi][ni].elt(3);
float tmp12 = acc[mi][ni].elt(6);
float tmp13 = acc[mi][ni].elt(7);
uint32_t x = fmha::float2_to_half2(tmp00, tmp01);
uint32_t y = fmha::float2_to_half2(tmp02, tmp03);
uint32_t z = fmha::float2_to_half2(tmp10, tmp11);
uint32_t w = fmha::float2_to_half2(tmp12, tmp13);
size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
}
}
}
template<int M, int N>
inline __device__ void store(const uint4 (&regs)[M][N]) {
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * Base::BYTES_PER_STS;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
}
}
}
uint32_t read_offset_;
};
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Sum_ {
enum { IS_SUM = 1 };
static inline __device__ float apply(float x, float y) {
return x + y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Max_ {
enum { IS_SUM = 0 };
static inline __device__ float apply(float x, float y) {
return x > y ? x : y;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp_(float x, float max) {
return __expf(x - max);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
// Helper class to distribute MMA tiles reduced over rows per warp over quads.
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
enum { MMAS_M = Mma_tile::MMAS_M };
enum { MMAS_N = Mma_tile::MMAS_N };
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
static_assert(THREADS_PER_GROUP == 16); // DEBUG
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
static_assert(LOOPS == 1);
using read_t = typename ReadType<COLS>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
int lane = tidx % 32;
int warp = tidx / 32;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane % 4;
int qp = lane / 4;
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
// This won't affect reading as we assume commutative reduction ops.
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
}
}
int qid_;
float *smem_write_;
read_t *smem_read_;
};
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
enum { MMAS_M = Mma_tile::MMAS_M };
enum { MMAS_N = Mma_tile::MMAS_N };
// The number of groups of warp such that we have at most 4 warps writing consecutive elements.
enum { GROUPS = fmha::Div_up<Cta_tile::WARPS_N, 4>::VALUE };
// The number of elements that we are going to store per row.
enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS };
// The number of rows.
enum { ROWS = Cta_tile::M * GROUPS };
// The total number of elements.
enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW };
// Ctor.
template<typename Params>
inline __device__ Softmax_base(const Params &params, void *smem, int bidb, int tidx)
: // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {
// Move to the 1st mask loaded by the thread+ tidx;
// packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// Decompose the warp index into M and N.
int warp_m = warp % Cta_tile::WARPS_M;
int warp_n = warp / Cta_tile::WARPS_M;
// Decompose the warp-n index into group/position-inside-the-group.
int warp_g = warp_n / ELEMENTS_PER_ROW;
int warp_i = warp_n % ELEMENTS_PER_ROW;
// The location written by the threads.
int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
int write_col = warp_i;
// Assemble the write pointer.
smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];
// Assemble the read pointer.
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
}
template<typename Mask>
inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ii = 0; ii < 2; ++ii ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
#pragma unroll
for( int jj = 0; jj < 4; ++jj ) {
if( !mask.is_valid(mi, ni, ii, jj) ) {
elt_[2 * mi + ii][4 * ni + jj] = -INFINITY;
}
}
}
}
}
}
// Apply the exp to all the elements.
inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
}
}
}
// Scale all the elements.
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
float inv_sum[MMAS_M * 2];
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
}
// Update the values.
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] *= inv_sum[mi];
}
}
}
// The pointer to the mask.
const char *packed_mask_ptr_;
// Shared memory for the CTA-wide reduction.
float *smem_, *smem_write_, *smem_read_;
// The current thread index.
int tidx_;
// The elements.
float elt_[MMAS_M * 2][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
// The base class.
using Base = Softmax_base<Cta_tile, Kernel_traits>;
// The fragment.
using Fragment_a = fmha::Fragment_a<fmha::Row>;
static_assert(Fragment_a::NUM_REGS == 4);
enum { WARPS_M = Cta_tile::WARPS_M };
enum { WARPS_N = Cta_tile::WARPS_N };
// The MMAs.
enum { MMAS_M = Base::MMAS_M };
enum { MMAS_N = Base::MMAS_N };
// The accumulators.
using Accumulator = fmha::Fragment_accumulator;
using Accumulator_out = Fragment<uint16_t, 8>;
static_assert(Accumulator_out::NUM_REGS == 4);
static_assert(std::is_same<Accumulator::Data_type, float>::value);
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
: Base(params, smem, bidb, tidx)
, params_scale_bmm1_(params.scale_bmm1)
, smem_sum_(static_cast<float*>(smem), tidx)
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
}
// Pack the data to a fragment for the next GEMM.
template<int K, int M>
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
#pragma unroll
for( int mi = 0; mi < M; ++mi ) {
#pragma unroll
for( int ki = 0; ki < K; ++ki ) {
// 1st row - 4 elements per row.
float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];
// 2nd row - 4 elements per row.
float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
// Pack to 4 registers.
dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
}
}
}
// Scale FP32 fragments
inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
}
}
}
// Scale FP32 fragments
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
}
}
}
template<typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
frag[mi] = this->elt_[mi][0];
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
const uint32_t params_scale_bmm1_;
Smem_tile_red smem_max_;
Smem_tile_red smem_sum_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Row {};
struct Col {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, bool = (M & (M-1)) == 0 >
struct Next_power_of_two {
};
template< int M >
struct Next_power_of_two< M, true > { enum { VALUE = M }; };
template<>
struct Next_power_of_two< 3, false> { enum { VALUE = 4 }; };
template<>
struct Next_power_of_two< 5, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 6, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 7, false> { enum { VALUE = 8 }; };
template<>
struct Next_power_of_two< 9, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 10, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 11, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 12, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 13, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 14, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 15, false> { enum { VALUE = 16 }; };
template<>
struct Next_power_of_two< 24, false> { enum { VALUE = 32 }; };
template<>
struct Next_power_of_two< 48, false> { enum { VALUE = 64 }; };
template<>
struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<112, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<144, false> { enum { VALUE = 256 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, bool = (N & (N-1)) == 0 >
struct Prev_power_of_two {
};
template< int N >
struct Prev_power_of_two< N, true > { enum { VALUE = N }; };
template<>
struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; };
template<>
struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, int N >
struct Div_up {
enum { VALUE = (M + N-1) / N };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Max {
enum { VALUE = A >= B ? A : B };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B, int C >
struct Max_3 {
enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Min {
enum { VALUE = A <= B ? A : B };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int SIZE_IN_BYTES >
struct Uint_from_size_in_bytes {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<1> {
using Type = uint8_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<2> {
using Type = uint16_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<4> {
using Type = uint32_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<8> {
using Type = uint2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<16> {
using Type = uint4;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int WARPS_M, int WARPS_N, int WARPS_K >
struct Warp_masks {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };
template<>
struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };
template<>
struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };
template<>
struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };
template<>
struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; };
template<>
struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };
template<>
struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };
template<>
struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };
template<>
struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };
template<>
struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };
template<>
struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };
template<>
struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };
template<>
struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };
template<>
struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T >
inline __device__ __host__ T div_up(T m, T n) {
return (m + n-1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int clz(int x) {
for( int i = 31; i >= 0; --i ) {
if( (1 << i) & x ) {
return 31 - i;
}
}
return 32;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int find_log_2(int x, bool round_up = false) {
int a = 31 - clz(x);
if( round_up ) {
a += (x & (x-1)) ? 1 : 0;
}
return a;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
uint2 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
uint4 c;
c.x = hmul2(a.x, b.x);
c.y = hmul2(a.y, b.y);
c.z = hmul2(a.z, b.z);
c.w = hmul2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
uint4 c;
c.x = hmul2(a, b.x);
c.y = hmul2(a, b.y);
c.z = hmul2(a, b.z);
c.w = hmul2(a, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
uint32_t res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
#else
const uint32_t zero = 0u;
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
return res;
}
static inline __device__ uint32_t habs2(uint32_t x) {
uint32_t res;
asm volatile( "abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
template< typename T >
static inline __device__ T clamp(T x, T lb, T ub) {
return x < lb ? lb : (x > ub ? ub : x);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
uint16_t mask;
asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
return mask & x;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t float_to_half(float f) {
uint16_t h;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
return h;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(float a, float b) {
uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
uint16_t lo = float_to_half(a);
uint16_t hi = float_to_half(b);
asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float_to_half2(float a) {
return float2_to_half2(a,a);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t float2_to_half2(const float2 &f) {
return float2_to_half2(f.x, f.y);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {
uint2 d;
d.x = float2_to_half2(x, y);
d.y = float2_to_half2(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {
uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
d = hrelu2(hfma2(a, b, c));
#endif
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h0_h0(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float h0_to_float(uint32_t h2) {
float f;
asm volatile("{\n" \
".reg .f16 lo, hi;\n" \
"mov.b32 {lo, hi}, %1;\n" \
"cvt.f32.f16 %0, lo;\n" \
"}\n" : "=f"(f) : "r"(h2));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t h1_h1(uint32_t x) {
uint32_t y;
asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n"
: "=r"(y) : "r"(x));
return y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {
return hadd2(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
uint2 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint2 hadd(uint2 a, uint2 b) {
return hadd4(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
uint4 c;
c.x = hadd2(a.x, b.x);
c.y = hadd2(a.y, b.y);
c.z = hadd2(a.z, b.z);
c.w = hadd2(a.w, b.w);
return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
float4 c;
c.x = reinterpret_cast<const float&>(a.x) + reinterpret_cast<const float&>(b.x);
c.y = reinterpret_cast<const float&>(a.y) + reinterpret_cast<const float&>(b.y);
c.z = reinterpret_cast<const float&>(a.z) + reinterpret_cast<const float&>(b.z);
c.w = reinterpret_cast<const float&>(a.w) + reinterpret_cast<const float&>(b.w);
return reinterpret_cast<const uint4&>(c);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint4 hadd(uint4 a, uint4 b) {
return hadd8(a, b);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float half_to_float(uint16_t h) {
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float2 half2_to_float2(uint32_t x) {
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
return make_float2(half_to_float(lo), half_to_float(hi));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
float2 tmp = half2_to_float2(h);
x = tmp.x;
y = tmp.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
uint16_t d;
asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
uint16_t d;
asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ float sigmoid(float x) {
return 1.f / (1.f + expf(-x));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint16_t &dst) {
dst = uint16_t(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint32_t &dst) {
dst = 0u;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint2 &dst) {
dst = make_uint2(0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void clear(uint4 &dst) {
dst = make_uint4(0u, 0u, 0u, 0u);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// P R E D I C A T E P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C P R E D I C A T E D L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M, typename Functor >
inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {
// The number of complete bytes (where we use all the predicates in a byte).
enum { COMPLETE = N / PREDS_PER_BYTE };
// Make sure we did allocate enough predicates.
static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
// The remainder.
enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };
// Make sure we got the math right and the remainder is between 0 and 3.
static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
// The mask to extract the predicates.
enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };
// Clear the fetch registers.
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
fct.clear(ii);
}
// Run complete steps.
bool p[PREDS_PER_BYTE];
#pragma unroll
for( int ii = 0; ii < COMPLETE; ++ii ) {
// The predicate.
uint32_t reg = preds[ii / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
}
}
// Skip the rest of the code if we do not have a remainder.
if( REMAINDER > 0 ) {
// The mask to extract the predicates.
enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };
// The predicate register.
uint32_t reg = preds[COMPLETE / BYTES_PER_REG];
// Extract the predicates.
#pragma unroll
for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
p[jj] = (reg & mask) != 0u;
}
// Issue the loads.
#pragma unroll
for( int ii = 0; ii < REMAINDER; ++ii ) {
fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, typename Functor >
inline __device__ void load_(Functor &fct, uint32_t preds) {
uint32_t tmp[1] = { preds };
load_<M>(fct, tmp);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint8_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint8_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint16_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint16_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint32_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint32_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint2 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint2*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint4 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint4*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
struct Ldg_functor {
// Ctor.
inline __device__ Ldg_functor(Data_type (&fetch)[N], const void* (&ptrs)[N])
: fetch_(fetch), ptrs_(ptrs) {
}
// Clear the element.
inline __device__ void clear(int ii) {
fmha::clear(fetch_[ii]);
}
// Trigger the loads.
inline __device__ void load(int ii, bool p) {
if( p ) {
ldg(fetch_[ii], ptrs_[ii]);
}
}
// The fetch registers.
Data_type (&fetch_)[N];
// The pointers.
const void* (&ptrs_)[N];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N, int M >
inline __device__ void ldg_(Data_type (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
Ldg_functor<Data_type, N> fct(fetch, ptrs);
load_<N>(fct, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint8_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint8_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint16_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint16_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint32_t (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint32_t, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint2 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint2, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N, int M >
inline __device__ void ldg(uint4 (&fetch)[N], const void* (&ptrs)[N], uint32_t (&preds)[M]) {
ldg_<uint4, N>(fetch, ptrs, preds);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint16_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint32_t &dst, uint32_t ptr) {
asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint2 &dst, uint32_t ptr) {
asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void lds(uint4 &dst, uint32_t ptr) {
asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x)
, "=r"(dst.y)
, "=r"(dst.z)
, "=r"(dst.w)
: "r"(ptr));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
: "=r"(dst) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsm(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint8_t val) {
*reinterpret_cast<uint8_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint16_t val) {
*reinterpret_cast<uint16_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint32_t val) {
*reinterpret_cast<uint32_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint2 val) {
*reinterpret_cast<uint2*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint4 val) {
*reinterpret_cast<uint4*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint16_t val) {
asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint32_t val) {
asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint2 val) {
asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void sts(uint32_t ptr, uint4 val) {
asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
:
: "r"(ptr)
, "r"(val.x)
, "r"(val.y)
, "r"(val.z)
, "r"(val.w));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Data_type, int N >
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
#pragma unroll
for( int ii = 0; ii < N; ++ii ) {
sts(ptrs[ii], data[ii]);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
sts_<uint16_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
sts_<uint32_t, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
sts_<uint2, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
sts_<uint4, N>(ptrs, data);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 128 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 256 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 384 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
#include "fmha_dgrad_kernel_1xN_reload_nl.h"
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;
extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
fmha::compute_dv_1xN<Kernel_traits>(params);
fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}
template<int CHUNKS>
__global__
void fmha_dgrad_fp16_512_64_sm80_nl_kernel(Fused_multihead_attention_fprop_params params){
fmha::compute_dv_1xN_nl<CHUNKS, Kernel_traits>(params);
fmha::compute_dq_dk_1xN_nl<CHUNKS, Kernel_traits>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed< Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b);
fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
void run_fmha_dgrad_fp16_512_64_sm80_nl(const Fused_multihead_attention_fprop_params &params, const int num_chunks, cudaStream_t stream) {
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
static_assert(smem_size_s == 16 * 512 * 2);
static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;
constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);
auto kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
if( num_chunks == 2 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<2>;
}else if( num_chunks == 3 ) {
kernel = fmha_dgrad_fp16_512_64_sm80_nl_kernel<3>;
} else {
assert(false && "Unsupperted number of chunks");
}
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(params.h, params.b, num_chunks);
kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dv_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dv =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);
static_assert(Cta_tile_dv::N == 64);
static_assert(Cta_tile_dv::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
// using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The shared memory tile to reload Q as fragment b.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dV.
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dV.
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;
static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);
static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dv::MMAS_K == 1);
smem_qt.load(frag_qt[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
Gmem_tile_s gmem_s(params, binfo, tidx);
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(
params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Load over the entire sequence length.
for( int l = 0; l < STEPS; l++ ) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
// Load S
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Store s * dmask to smem for transpose
smem_s.store(s_regs);
// Declare the accumulators for the 1st gemm.
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if( l < STEPS - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
float s_mat[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);
}
}
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
const float d_s = drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
s_dmask = fabsf(s_dmask);
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * fabsf(s_dmask);
}
}
}
}
float p_sum[2 * M];
softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;
softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
}
}
}
}
typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
smem_s.load(frag_s);
for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {
for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {
frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
#pragma unroll
for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dv::MMAS_K;
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if(l < STEPS - 1) {
gmem_q.commit(smem_q);
}
// Make sure we are reading from the correct buffer.
smem_q.move_to_next_read_buffer();
smem_qt.move_to_next_read_buffer();
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_q.load(frag_q[0], 0);
smem_k.load(frag_k[0], 0);
smem_qt.load(frag_qt[0], 0);
} // Outer loop over the sequence length.
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
smem_dv.store(acc_dv);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dqkv_ptr;
dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
gmem_dv.store(dv_out);
}
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dq_dk_1xN(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dk =
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
static_assert(Cta_tile_dk::N == 64);
static_assert(Cta_tile_dk::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
// using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
using Gmem_tile_o = fmha::Gmem_tile_dq<Cta_tile_o>;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dK.
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle dK.
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
// The shared memory tile to reload Q transposed.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
static_assert(M == Mma_tile_o::MMAS_M);
static_assert(N == Mma_tile_o::MMAS_K);
// Shared memory.
extern __shared__ char smem_[];
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
Gmem_tile_s gmem_s(params, binfo, tidx);
// Load dP
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
gmem_s.move();
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
smem_qt.load(frag_qt[0], 0);
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
enum { THREADS_PER_ROW = 32 };
enum { STEPS = Cta_tile_p::N / Cta_tile_p::M };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for( int l=0;l<STEPS;l++) {
const int loop = l * Cta_tile_p::M;
if( loop >= binfo.actual_seqlen )
break;
// Pack dP as Fragment_a
fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
frag_p[ni][mi].reg(0) = dst.x; // row 0, cols 0,1
frag_p[ni][mi].reg(1) = dst.z; // row 8, cols 0,1
frag_p[ni][mi].reg(2) = dst.y; // row 0, cols 8,9
frag_p[ni][mi].reg(3) = dst.w; // row 8, cols 8,9
}
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T. dQ = dP x dK
#pragma unroll
for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_o::MMAS_K;
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Store dP to smem for transpose
smem_s.store(s_regs);
if(l < STEPS - 1) {
// Load next part of S
gmem_s.load(s_regs, mask);
gmem_s.move();
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
smem_s.load(frag_s);
#pragma unroll
for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dk::MMAS_K;
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if( l < STEPS - 1) {
gmem_q.commit(smem_q);
}
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_qt.load(frag_qt[0], 0);
smem_k.load(frag_k[0], 0);
} // Outer loop over the sequence length.
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[0], tidx);
smem_dk.store(acc_dk);
__syncthreads();
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Qkv_params dk_params;
dk_params.qkv_ptr = params.dqkv_ptr;
dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int CHUNKS, typename Kernel_traits, typename Params>
inline __device__ void compute_dv_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dv = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dv::M == 512 || Cta_tile_dv::M == 384 || Cta_tile_dv::M == 256 || Cta_tile_dv::M == 128);
static_assert(Cta_tile_dv::N == 64);
static_assert(Cta_tile_dv::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dv = fmha::Hmma_tile<Cta_tile_dv>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The shared memory tile to reload Q as fragment b.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store dV.
using Gmem_tile_dv = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o,
fmha::BITS_PER_ELEMENT_B,
Cta_tile_p::N, //S,
Cta_tile_p::K, //D,
2*CHUNKS>;
// The shared memory tile to swizzle dV.
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dv>;
static_assert(Smem_tile_dv::NUM_LDS == Gmem_tile_dv::LDGS);
static_assert(Smem_tile_dv::THREADS_PER_ROW == Gmem_tile_dv::THREADS_PER_ROW);
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
// Shared memory.
extern __shared__ char smem_[];
// The block index for the chunk.
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_do gmem_q(params, binfo, tidx); // treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q smem_q(&smem_[0], tidx);
Smem_tile_qt smem_qt(&smem_[0], tidx);
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 2, binfo, tidx); // treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
Gmem_tile_s gmem_s(params, binfo, tidx);
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_q);
// Trigger the loads for K.
gmem_k.load(smem_k);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_q);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
// Load the fragments for Q.
typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
smem_q.load(frag_q[0], 0);
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dv::MMAS_N];
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
static_assert(Mma_tile_dv::MMAS_K == 1);
smem_qt.load(frag_qt[0], 0);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_p::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
// Create the object to do the softmax.
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
Softmax softmax(
params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], bidb, tidx);
enum { THREADS_PER_ROW = 32 };
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dv::WARPS_K>::apply(acc_dv);
// Load over the entire sequence length.
for(int l = 0; l < nl_traits.num_steps_;l++) {
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_q.load(frag_q[ki & 1], ki);
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
smem_s.store(s_regs);
// Declare the accumulators for the 1st gemm.
// Do the final stage of math.
{
int ki = Mma_tile_p::MMAS_K;
fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
}
// Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
if(l < nl_traits.num_steps_ - 1) {
smem_q.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_q);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack(acc_p);
float s_mat[2 * M][4 * N];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 0], s_mat[2 * mi + 0][4 * ni + 1], dst.x);
fmha::half2_to_float2(s_mat[2 * mi + 0][4 * ni + 2], s_mat[2 * mi + 0][4 * ni + 3], dst.y);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 0], s_mat[2 * mi + 1][4 * ni + 1], dst.z);
fmha::half2_to_float2(s_mat[2 * mi + 1][4 * ni + 2], s_mat[2 * mi + 1][4 * ni + 3], dst.w);
}
}
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
float & s_dmask = s_mat[2 * mi + ii][4 * ni + jj];
const bool drop = reinterpret_cast<const uint32_t &>(s_dmask) & 0x80000000;
const float d_s= drop ? 0.f : softmax.elt_[2 * mi + ii][4 * ni + jj] * params.rp_dropout;
s_dmask = fabsf(s_dmask);
softmax.elt_[2 * mi + ii][4 * ni + jj] = d_s * (s_dmask);
}
}
}
}
float p_sum[2 * M];
softmax.reduce_sum(p_sum);
const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ii = 0; ii < 2; ii++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
#pragma unroll
for( int jj = 0; jj < 4; jj++ ) {
softmax.elt_[2 * mi + ii][4 * ni + jj] -= p_sum[2 * mi + ii] * (s_mat[2 * mi + ii][4 * ni + jj]) ;
softmax.elt_[2 * mi + ii][4 * ni + jj] *= scalef;
}
}
}
}
typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
smem_s.load(frag_s);
for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {
for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {
frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
}
}
}
gmem_s.store(softmax.elt_, mask);
gmem_s.move();
static_assert(Mma_tile_dv::MMAS_K == 1); // DEBUG
#pragma unroll
for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dv::MMAS_K;
fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if(l < nl_traits.num_steps_ - 1) {
gmem_q.commit(smem_q);
}
// Make sure we are reading from the correct buffer.
smem_q.move_to_next_read_buffer();
smem_qt.move_to_next_read_buffer();
// Make sure the data is in shared memory.
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_q.load(frag_q[0], 0);
smem_k.load(frag_k[0], 0);
smem_qt.load(frag_qt[0], 0);
} // Outer loop over the sequence length.
// Epilogue for dV = (S * D)' * dout'. We're fully exposed to this!
// Epilogue swizzle for dV
Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);
smem_dv.store(acc_dv);
__syncthreads();
uint4 dv_out[Smem_tile_dv::NUM_LDS];
smem_dv.load(dv_out);
Qkv_params dv_params;
dv_params.qkv_ptr = params.dkv_ptr;
dv_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);
dv_params.h = params.h;
Gmem_tile_dv gmem_dv(dv_params, nl_traits.get_idx_dv(), binfo, tidx);
gmem_dv.store(dv_out);
}
template<int CHUNKS, typename Kernel_traits, typename Params>
inline __device__ void compute_dq_dk_1xN_nl(const Params &params) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_dk = fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
static_assert(Cta_tile_dk::N == 64);
static_assert(Cta_tile_dk::K == 16);
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;
// The global memory tile to load Q.
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
// The shared memory tile to swizzle Q.
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
// The global memory tile to load K.
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle K.
using Smem_tile_k = typename Kernel_traits::Smem_tile_v; // K is used like V in fprop
// The global memory tile to load V.
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
// The shared memory tile to swizzle V.
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
// The global memory tile to store O.
using Gmem_tile_o = Gmem_tile_dq<Cta_tile_o>;
// The shared memory tile to swizzle O.
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
// The global memory tile to store dK.
using Gmem_tile_dk = fmha::Gmem_tile_qkv<typename Kernel_traits::Cta_tile_o,
fmha::BITS_PER_ELEMENT_B,
Cta_tile_p::N, //S,
Cta_tile_p::K, //D,
2*CHUNKS>;
// The shared memory tile to swizzle dK.
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
// The shared memory tile to reload Q transposed.
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
// The global memory tile to load dP, stored in S
using Gmem_tile_s = Gmem_tile_mma_s<Cta_tile_p>;
// The shared memory tile to transpose dP.
using Smem_tile_st = Smem_tile_mma_transposed<Cta_tile_p>;
using Noloop = Noloop_traits<CHUNKS, Cta_tile_p>;
enum { M = Mma_tile_p::MMAS_M };
enum { N = Mma_tile_p::MMAS_N };
static_assert(M == Mma_tile_o::MMAS_M);
static_assert(N == Mma_tile_o::MMAS_K);
// Shared memory.
extern __shared__ char smem_[];
const int bidc = blockIdx.z;
// The block index for the batch.
const int bidb = blockIdx.y;
// The block index for the head.
const int bidh = blockIdx.x;
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early() )
return;
fmha::Mask<Cta_tile_p> mask(params, binfo, tidx);
// Allocate the global memory tile loader for Q.
Gmem_tile_q gmem_q(params, 0, binfo, tidx);
// Allocate the shared memory tile loader for Q (as B).
Smem_tile_qt smem_qt(&smem_[0], tidx);
// Allocate the global memory tile loader for dP.
Gmem_tile_s gmem_s(params, binfo, tidx);
// Allocate the shared memory tile loader for dP.
Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for K.
Gmem_tile_k gmem_k(params, 1, binfo, tidx);
// Allocate the shared memory tile loader for K.
Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);
// Allocate the global memory tile loader for O.
Gmem_tile_o gmem_o(params, binfo, tidx);
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);
Noloop nl_traits(bidc, binfo);
nl_traits.move_all(gmem_q, gmem_o, gmem_s);
// Trigger the loads for Q.
gmem_q.load(smem_qt);
// Trigger the loads for K.
gmem_k.load(smem_k);
uint4 s_regs[M][N];
gmem_s.load(s_regs, mask);
// Commit the data for Q and K to shared memory.
gmem_q.commit(smem_qt);
gmem_k.commit(smem_k);
// Make sure the data is in shared memory.
__syncthreads();
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
smem_qt.load(frag_qt[0], 0);
typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
smem_k.load(frag_k[0], 0);
enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };
enum { THREADS_PER_ROW = 32 };
// Declare the accumulators for the 2nd gemm.
fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);
// Load over the entire sequence length.
for(int l=0;l < nl_traits.num_steps_; l++) {
// Pack dP as Fragment_a
fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 &dst = s_regs[mi][ni];
frag_p[ni][mi].reg(0) = dst.x;
frag_p[ni][mi].reg(1) = dst.z;
frag_p[ni][mi].reg(2) = dst.y;
frag_p[ni][mi].reg(3) = dst.w;
}
}
smem_s.store(s_regs);
if(l < nl_traits.num_steps_- 1) {
// Load next part of S
gmem_s.move();
gmem_s.load(s_regs, mask);
// Trigger the load for the next Q values.
smem_qt.move_to_next_write_buffer();
gmem_q.move();
gmem_q.load(smem_qt);
}
// Declare the accumulators for the 1st gemm.
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
// Do this part of O = P^T * V^T. dQ = dP x dK
#pragma unroll
for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_k.load(frag_k[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_o::MMAS_K;
fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
}
static_assert(Gmem_tile_o::LOOPS == 1); //DEBUG
// Loop over MMAS_M.
#pragma unroll
for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o, ii);
// Make sure the data is in shared memory.
__syncthreads();
// Load from shared memory.
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
smem_o.load(out);
// Make sure the data was read from shared memory.
if( ii < Gmem_tile_o::LOOPS - 1 ) {
__syncthreads();
}
// Output the values.
gmem_o.store(out, ii);
}
// Move to the next part of the output.
gmem_o.move();
typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
smem_s.load(frag_s);
static_assert(Mma_tile_dk::MMAS_K == 1); // DEBUG
#pragma unroll
for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
smem_qt.load(frag_qt[ki & 1], ki);
// Do the math for the values already in registers.
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Do the final stage of math.
{
int ki = Mma_tile_dk::MMAS_K;
fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
}
// Commit the values for Q into shared memory.
if(l < nl_traits.num_steps_- 1) {
gmem_q.commit(smem_qt);
__syncthreads();
// Trigger the loads for the values of Q for the next iteration.
smem_qt.load(frag_qt[0], 0);
smem_k.load(frag_k[0], 0);
}
} // Outer loop over the sequence length.
// Epilogue for dK = dP' * dq. We're fully exposed to this!
// Epilogue swizzle for dK
Smem_tile_dk smem_dk(&smem_[0], tidx);
smem_dk.store(acc_dk);
__syncthreads();
uint4 dk_out[Smem_tile_dk::NUM_LDS];
smem_dk.load(dk_out);
Qkv_params dk_params;
dk_params.qkv_ptr = params.dkv_ptr;
dk_params.qkv_stride_in_bytes = params.h * 2 * CHUNKS * params.d * sizeof(half);
dk_params.h = params.h;
Gmem_tile_dk gmem_dk(dk_params, nl_traits.get_idx_dk(), binfo, tidx);
gmem_dk.store(dk_out);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
template<bool Is_training>
__global__
void fmha_fprop_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_128_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_128_64_sm80_kernel<true> : &fmha_fprop_fp16_128_64_sm80_kernel<false>;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
template<bool Is_training>
__global__
void fmha_fprop_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params,
const int num_full_heads,
const int num_main_groups,
const int main_group_size,
const int main_steps,
const int rest_steps) {
fmha::device_1xN<Kernel_traits, Is_training>(
params, num_full_heads, num_main_groups, main_group_size, main_steps, rest_steps);
}
void run_fmha_fp16_256_64_sm80(Launch_params<Fused_multihead_attention_fprop_params> &launch_params, const bool configure) {
auto kernel = launch_params.is_training ? &fmha_fprop_fp16_256_64_sm80_kernel<true> : &fmha_fprop_fp16_256_64_sm80_kernel<false>;
constexpr int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>();
if( smem_size >= 48 * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
const int sm_count = launch_params.props->multiProcessorCount;
int ctas_per_sm;
FMHA_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size));
int total_ctas = sm_count * ctas_per_sm;
if(configure) {
const int heads_total = launch_params.params.b * launch_params.params.h;
std::tie(launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps,
launch_params.elts_per_thread) = fmha::work_dist<Kernel_traits>(total_ctas, heads_total);
return;
}
dim3 grid(total_ctas);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params,
launch_params.num_full_heads,
launch_params.num_main_groups,
launch_params.heads_last_wave,
launch_params.main_steps,
launch_params.rest_steps);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment