Unverified commit b8be1bc7, authored by Burc Eryilmaz and committed by GitHub

initial cublaslt support for MLP (#1080)



* initial cublaslt support

* 64 bit input

* add license headers

* cleanup

* remove license
Co-authored-by: pbialecki <pbialecki@nvidia.com>
parent b5eb38db
@@ -4,7 +4,7 @@
#include <stdio.h>
-size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features);
+size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);
template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);
@@ -21,7 +21,8 @@ int mlp_fp(
T* Y,
T* reserved_space,
int use_bias,
-int activation);
+int activation,
void* lt_workspace);
template <typename T>
int mlp_bp(
@@ -60,9 +61,10 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
-// TODO(deyuf): just get buffer?
auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, inputs[0].type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
std::vector<scalar_t*> w_ptr;
@@ -84,7 +86,8 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
out.data_ptr<scalar_t>(),
reserved_space.data_ptr<scalar_t>(),
use_bias,
-activation);
+activation,
(void*) (lt_workspace.data_ptr<scalar_t>()));
});
return {out, reserved_space};
@@ -106,7 +109,6 @@ std::vector<at::Tensor> mlp_backward(
auto batch_size = inputs[0].size(0);
auto input_features = inputs[0].size(1);
-// TODO: not creating empty tensor for it?
bool requires_grad = inputs[0].requires_grad();
std::vector<int> output_features;
@@ -114,7 +116,6 @@
output_features.push_back(inputs[i + 1].size(0));
}
// create outputs, length of inputs
-// TODO: not create bias if not needed
std::vector<at::Tensor> outputs;
for (int i = 0; i < inputs.size(); i++) {
outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now
@@ -162,3 +163,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &mlp_forward, "MLP forward");
m.def("backward", &mlp_backward, "MLP backward");
}
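A note on the workspace sizing in mlp_forward above: lt_workspace is allocated as 1 << 22 elements of the dispatched dtype, while the CUDA source that follows declares only 1 << 22 bytes (4 MB) to cuBLASLt, so the buffer is always at least as large as the advertised maximum. A minimal sketch of that arithmetic; the helper name is hypothetical and not part of the patch:

#include <cstddef>

// Hypothetical helper, illustration only: byte size of a workspace allocated
// as (1 << 22) elements of a scalar type of elem_size bytes.
constexpr std::size_t lt_workspace_bytes(std::size_t elem_size) {
  return (std::size_t{1} << 22) * elem_size;
}

static_assert(lt_workspace_bytes(2) == (std::size_t{8} << 20),  "fp16 -> 8 MiB");
static_assert(lt_workspace_bytes(4) == (std::size_t{16} << 20), "fp32 -> 16 MiB");
static_assert(lt_workspace_bytes(8) == (std::size_t{32} << 20), "fp64 -> 32 MiB");

This matches the "this gets at least 4 MB" comment: even the smallest dispatched type (half) yields an 8 MiB allocation against the 4 MB declared to the cuBLASLt heuristic in the CUDA file below.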
@@ -10,6 +10,9 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
// includes cublaslt
#include <cublasLt.h>
// constants for fused bias+relu kernel
#define BIAS_RELU_FW_NTHREADS 128 // forward number of thread per block
#define BIAS_RELU_BW_NTHREADS_X 32 // backward number of thread in feature dim
@@ -165,6 +168,268 @@ cublasStatus_t mlp_gemm(
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
int mlp_gemm_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host pointer */
const at::Half* A,
int lda,
const at::Half* B,
int ldb,
float *beta, /* host pointer */
at::Half* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
bool use_relu,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
} else {
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU;
}
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
&heuristicResult.algo,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int mlp_gemm_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host pointer */
const double* A,
int lda,
const double* B,
int ldb,
float *beta, /* host pointer */
double* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
bool use_relu,
const void* bias) {
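// No cublasLt path for double precision: returning nonzero here makes the
// caller (mlp_fp) fall back to the regular cublas GEMM.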
return 1;
}
int mlp_gemm_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
float *alpha, /* host pointer */
const float *A,
int lda,
const float *B,
int ldb,
float *beta, /* host pointer */
float *C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
bool use_relu,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU_BIAS;
} else {
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
} else {
if (use_relu) {
epilogue = CUBLASLT_EPILOGUE_RELU;
}
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_32F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_32F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_32F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
&heuristicResult.algo,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
// Bias ADD. Assume input X is [features x batch size], column major.
// Bias is one 'features' long vector, with implicit broadcast.
template <typename T>
@@ -498,7 +763,7 @@ __global__ void biasAdd_bprop(
int nidx = 0;
// Handle non-multiple of UNROLL_FACTOR residue
for (; nidx < nSpan % UNROLL_FACTOR; nidx++) {
-int row, col, flat_idx;
+int64_t row, col, flat_idx;
row = f;
col = nStart + nidx;
flat_idx = col * features + row;
@@ -507,7 +772,7 @@
// Handle meat of work
for (; (nidx + UNROLL_FACTOR - 1) < nSpan; nidx += UNROLL_FACTOR) {
-int row, col, flat_idx;
+int64_t row, col, flat_idx;
row = f;
col = nStart + nidx;
flat_idx = col * features + row;
@@ -780,7 +1045,6 @@ __global__ void biasAddRelu_bprop_aligned(
}
// block result is in db_local now for all threadIdx.y == 0
-// TODO: maybe not useful early exit here
if(gridDim.y == 1) {
#pragma unroll
for(int ii=0;ii<ILP;ii++){
@@ -847,7 +1111,7 @@ void get_y_offsets(
}
// Returns the reserved space (in elements) needed for the MLP
-size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features) {
+size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features) {
size_t res_space = 0;
// Need to store output of every intermediate MLP - size equal to output_features[i] * batch_size
// for all 'i' in [0, num_layers-1)
@@ -858,7 +1122,7 @@ size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_
}
// Returns the size of all fprop activations combined
-size_t get_all_activations_size(int batch_size, int num_layers, const int* output_features) {
+size_t get_all_activations_size(int64_t batch_size, int num_layers, const int* output_features) {
size_t acts_size = 0;
for (int l = 0; l < num_layers; l++) {
acts_size += output_features[l] * batch_size;
@@ -979,7 +1243,8 @@ int mlp_fp(
T* Y,
T* reserved_space,
int use_bias,
-int activation) {
+int activation,
void* lt_workspace) {
T *weight, *input, *output, *bias;
T *reserved_space_x, *reserved_space_y;
reserved_space_x = NULL;
@@ -987,6 +1252,9 @@ int mlp_fp(
// Get cublas handle from Pytorch
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cublasLtHandle_t ltHandle;
cublasStatus_t lthandle_status;
lthandle_status = cublasLtCreate(&ltHandle);
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
@@ -1004,9 +1272,37 @@ int mlp_fp(
float one = 1.f;
float zero = 0.f;
-cublasStatus_t cublas_status;
-// Call GEMM: fprop is Y = W'X
-cublas_status = mlp_gemm(
// try with cublaslt first for supported case with valid handle
int cublaslt_status = 1;
if(lthandle_status == CUBLAS_STATUS_SUCCESS && activation < 2){
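// (activation: 0 = none, 1 = relu, 2 = sigmoid; the cublasLt epilogue fuses bias
// and relu, so sigmoid keeps using the separate kernel path in the fallback below)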
cublaslt_status = mlp_gemm_lt(
ltHandle,
CUBLAS_OP_T,
CUBLAS_OP_N,
ofeat,
batch_size,
ifeat,
&one,
weight,
ifeat,
input,
ifeat,
&zero,
output,
ofeat,
lt_workspace,
1 << 22,
stream,
use_bias == 1,
activation == 1,
bias);
}
// if cublaslt failed or not executed, fallback to cublas
if (cublaslt_status != 0) {
cublasStatus_t cublas_status;
// Call GEMM: fprop is Y = W'X
cublas_status = mlp_gemm(
handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
@@ -1022,45 +1318,48 @@
output,
ofeat);
if (cublas_status != CUBLAS_STATUS_SUCCESS) {
printf("GEMM fprop failed with %d\n", cublas_status);
return 1;
}
const uint &input_size = ofeat;
int num_blocks = 0;
int num_SMs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// Call biasReLU
if(use_bias == 1) {
if (activation == 0) { // no activation
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAddRelu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAddRelu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, biasAdd_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
biasAdd_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, bias, batch_size, input_size);
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
} else {
// don't need to do anything in case of no activation and no bias
if (activation == 1) { // relu
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Relu_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Relu_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
} else if (activation == 2) { // sigmoid
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks, Sigmoid_fprop<T>, BIAS_RELU_FW_NTHREADS, 0);
Sigmoid_fprop<<<num_SMs*num_blocks, BIAS_RELU_FW_NTHREADS, 0, stream>>>(output, batch_size, input_size);
}
}
}
// Set current output as next layer input
reserved_space_x = reserved_space_y;
// Set next layer output
reserved_space_y += ofeat * batch_size;
}
if(lthandle_status == CUBLAS_STATUS_SUCCESS) cublasLtDestroy(ltHandle);
return 0;
}
@@ -1281,7 +1580,8 @@ template int mlp_fp<float>(
float* Y,
float* reserved_space,
int use_bias,
-int activation);
+int activation,
void* lt_workspace);
template int mlp_bp<float>(
float* X,
@@ -1312,7 +1612,8 @@ template int mlp_fp<at::Half>(
at::Half* Y,
at::Half* reserved_space,
int use_bias,
-int activation);
+int activation,
void* lt_workspace);
template int mlp_bp<at::Half>(
at::Half* X,
@@ -1343,7 +1644,8 @@ template int mlp_fp<double>(
double* Y,
double* reserved_space,
int use_bias,
-int activation);
+int activation,
void* lt_workspace);
template int mlp_bp<double>(
double* X,
@@ -1375,3 +1677,4 @@ template size_t get_mlp_bp_workspace_in_bytes<double>(
int batch_size,
int num_layers,
const int* output_features);
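The comment inside mlp_gemm_lt notes that only the single best heuristic is requested and that, for example with badly aligned operands, one could request more algorithms (e.g. 32) and try them one by one. A rough sketch of that retry loop, assuming the operation descriptor, layouts, and preference are already initialized exactly as in the patch; the helper name and structure are illustrative, not part of this commit:

#include <cublasLt.h>

// Illustrative only: query up to 32 heuristics and attempt them in order,
// returning on the first cublasLtMatmul call that succeeds.
static cublasStatus_t try_matmul_heuristics(
    cublasLtHandle_t ltHandle,
    cublasLtMatmulDesc_t operationDesc,
    cublasLtMatrixLayout_t Adesc,
    cublasLtMatrixLayout_t Bdesc,
    cublasLtMatrixLayout_t Cdesc,
    cublasLtMatmulPreference_t preference,
    const float* alpha, const void* A, const void* B,
    const float* beta, void* C,
    void* workspace, size_t workspaceSize, cudaStream_t stream) {
  constexpr int kRequested = 32;
  cublasLtMatmulHeuristicResult_t results[kRequested] = {};
  int returned = 0;
  cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
      ltHandle, operationDesc, Adesc, Bdesc, Cdesc, Cdesc,
      preference, kRequested, results, &returned);
  if (status != CUBLAS_STATUS_SUCCESS || returned == 0) {
    return CUBLAS_STATUS_NOT_SUPPORTED;
  }
  for (int i = 0; i < returned; i++) {
    // Candidates come back in heuristic order; the first one that runs wins.
    status = cublasLtMatmul(ltHandle, operationDesc, alpha, A, Adesc, B, Bdesc,
                            beta, C, Cdesc, C, Cdesc, &results[i].algo,
                            workspace, workspaceSize, stream);
    if (status == CUBLAS_STATUS_SUCCESS) return status;
  }
  return status;
}

mlp_gemm_lt could call a helper like this in place of its single-result query; the epilogue and workspace-preference setup in the patch would stay unchanged.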