Commit e68ebbe8 authored by Tri Dao

Simplify FusedDense

parent 1bc6e5b0
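For orientation, a minimal usage sketch of the simplified Python API this commit converges on. Class names, constructor signatures, and the flash_attn.ops.fused_dense import path are taken from the diff below; the tensor shapes and the assumption that FusedDenseGeluDense is called like an ordinary nn.Module are illustrative, not part of the commit.

import torch
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense

x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)

# FusedDense replaces the old FusedDenseTD / FusedDenseResidual pair:
# with return_residual=True, forward() returns (out, x) so the residual
# branch's gradient can be fused into the matmul backward.
fc = FusedDense(1024, 1024, bias=True, return_residual=True,
                device='cuda', dtype=torch.float16)
out, residual = fc(x)

# FusedDenseGeluDense now takes hidden_features and separate bias1/bias2 flags.
mlp = FusedDenseGeluDense(1024, 4096, bias1=True, bias2=True,
                          checkpoint_lvl=0, heuristic=0,
                          device='cuda', dtype=torch.float16)
y = mlp(x)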
......@@ -6,6 +6,8 @@
#include <stdio.h>
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
......@@ -24,14 +26,6 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);
template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
......@@ -39,103 +33,34 @@ template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace) ;
template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_forward_cuda<scalar_t>(
input,
w_ptr,
bias,
in_features,
batch_size,
out_features,
out,
//out.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_forward failed.")
});
return {out};
}
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
#endif
auto d_input = at::empty({batch_size, in_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/false,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_backward failed.")
});
return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
int batch_size = input.size(0);
int in_features = input.size(1);
int out_features = d_output.size(1);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == d_output.dtype());
TORCH_CHECK(input.is_cuda());
TORCH_CHECK(d_output.is_cuda());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(d_output.is_contiguous());
CHECK_SHAPE(input, batch_size, in_features);
CHECK_SHAPE(d_output, batch_size, out_features);
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
at::Tensor d_bias;
if (has_d_bias) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
d_bias = at::empty({out_features}, opts);
#endif
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
}
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
......@@ -147,93 +72,59 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output)
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
});
return {d_weight, d_bias};
}
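The wgrad binding above now takes a has_d_bias flag and skips the bias-gradient epilogue when it is false. A hedged sketch of the Python-side call, mirroring how FusedDenseFunc.backward later in this diff uses it; the tensor names and sizes are placeholders, only the fused_dense_lib module name and the argument order come from the source.

import torch
import fused_dense_lib as fused_dense_cuda

batch, in_features, out_features = 4096, 1024, 4096
x2d = torch.randn(batch, in_features, device='cuda', dtype=torch.float16)
dout2d = torch.randn(batch, out_features, device='cuda', dtype=torch.float16)

# d_weight: (out_features, in_features); d_bias: (out_features,) because has_d_bias=True
d_weight, d_bias = fused_dense_cuda.linear_bias_wgrad(x2d, dout2d, True)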
std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
#endif
CHECK_SHAPE(d_input, batch_size, in_features);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/true,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
});
return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
c10::optional<at::Tensor> bias_,
bool save_gelu_in, int heuristic) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int batch_size = input.size(0);
int in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == weight.dtype());
TORCH_CHECK(input.is_cuda());
TORCH_CHECK(weight.is_cuda());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
CHECK_SHAPE(input, batch_size, in_features);
CHECK_SHAPE(weight, out_features, in_features);
if (bias_.has_value()) {
auto bias = bias_.value();
TORCH_CHECK(bias.dtype() == input.dtype());
TORCH_CHECK(bias.is_cuda());
TORCH_CHECK(bias.is_contiguous());
CHECK_SHAPE(bias, out_features);
}
// create output/workspace tensor
auto opts = input.options();
auto output = at::empty({batch_size, out_features}, opts);
at::Tensor gelu_in;
if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
scalar_t* b_ptr = bias.data_ptr<scalar_t>();
auto result = linear_gelu_forward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
b_ptr,
weight.data_ptr<scalar_t>(),
bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
in_features,
batch_size,
out_features,
heuristic,
output.data_ptr<scalar_t>(),
save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
});
std::vector<at::Tensor> result = {output};
......@@ -241,116 +132,54 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
return result;
}
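linear_gelu_forward now takes an optional bias and returns the pre-GELU activation only when save_gelu_in is true. A hedged Python-side sketch of the call, matching the usage in the autograd function later in this diff; the shapes and the heuristic value are illustrative.

import torch
import fused_dense_lib as fused_dense_cuda

batch, in_features, hidden_features = 4096, 1024, 4096
x2d = torch.randn(batch, in_features, device='cuda', dtype=torch.float16)
weight1 = torch.randn(hidden_features, in_features, device='cuda', dtype=torch.float16)
bias1 = torch.randn(hidden_features, device='cuda', dtype=torch.float16)

# With save_gelu_in=True both the GELU output and gelu_in are returned,
# so the backward pass can avoid recomputing the first GEMM.
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x2d, weight1, bias1, True, 0)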
std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
int batch_size = d_output.size(0);
int out_features = d_output.size(1);
int in_features = weight.size(1);
TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
TORCH_CHECK(weight.dtype() == d_output.dtype());
TORCH_CHECK(weight.dtype() == gelu_in.dtype());
TORCH_CHECK(weight.is_cuda());
TORCH_CHECK(d_output.is_cuda());
TORCH_CHECK(gelu_in.is_cuda());
TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(d_output.is_contiguous());
TORCH_CHECK(gelu_in.is_contiguous());
CHECK_SHAPE(weight, out_features, in_features);
CHECK_SHAPE(d_output, batch_size, out_features);
CHECK_SHAPE(gelu_in, batch_size, in_features);
// create output/workspace tensor
auto opts = input.options();
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
auto d_bias1 = at::empty({hidden_features}, opts);
auto d_bias2 = at::empty({out_features}, opts);
auto opts = weight.options();
auto d_bias = at::empty({in_features}, opts);
auto d_input = at::empty({batch_size, in_features}, opts);
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
output1.data_ptr<scalar_t>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
d_output1.data_ptr<scalar_t>(),
d_output2.data_ptr<scalar_t>(),
in_features,
batch_size,
hidden_features,
out_features,
heuristic,
d_weight1.data_ptr<scalar_t>(),
d_weight2.data_ptr<scalar_t>(),
d_bias1.data_ptr<scalar_t>(),
d_bias2.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/false,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
});
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
auto d_bias1 = at::empty({hidden_features}, opts);
auto d_bias2 = at::empty({out_features}, opts);
CHECK_SHAPE(d_input, batch_size, in_features);
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
weight.data_ptr<scalar_t>(),
d_output.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
output1.data_ptr<scalar_t>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
d_output1.data_ptr<scalar_t>(),
d_output2.data_ptr<scalar_t>(),
in_features,
batch_size,
hidden_features,
out_features,
heuristic,
d_weight1.data_ptr<scalar_t>(),
d_weight2.data_ptr<scalar_t>(),
d_bias1.data_ptr<scalar_t>(),
d_bias2.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/true,
d_bias.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
});
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
return {d_input, d_bias};
}
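bias_gelu_linear_dgrad_bgrad fuses the data gradient of the second GEMM, the backward of the tanh-approximate GELU, and the bias gradient of the first linear layer into a single cuBLASLt call. As a hedged plain-PyTorch reference of the intended math (not the fused implementation; the function below is illustrative only), with weight2 of shape (out_features, hidden), d_output of shape (batch, out_features) and gelu_in of shape (batch, hidden):

import torch

def bias_gelu_linear_dgrad_bgrad_ref(weight2, d_output, gelu_in):
    d_output1 = d_output @ weight2                      # dgrad through the second GEMM
    t = torch.tanh(0.79788456 * gelu_in * (1.0 + 0.044715 * gelu_in * gelu_in))
    dgelu = 0.5 * (1.0 + t) + 0.5 * gelu_in * (1.0 - t * t) * (
        0.79788456 + 0.1070322243 * gelu_in * gelu_in)  # derivative of tanh-approx GELU
    d_gelu_in = d_output1 * dgelu                        # gradient w.r.t. gelu_in
    d_bias1 = d_gelu_in.sum(dim=0)                       # bgrad of the first linear layer
    return d_gelu_in, d_bias1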
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
}
......@@ -94,226 +94,6 @@ cublasStatus_t gemm_bias(
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
int gemm_bias_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
at::Half* A,
int lda,
at::Half* B,
int ldb,
const float *beta, /* host pointer */
at::Half* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16F, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16F, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16F, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
cublasOperation_t transb,
int m,
int n,
int k,
const float *alpha, /* host pointer */
at::BFloat16* A,
int lda,
at::BFloat16* B,
int ldb,
const float *beta, /* host pointer */
at::BFloat16* C,
int ldc,
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bias) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
cublasLtMatmulDescOpaque_t operationDesc = {};
cublasLtMatrixLayoutOpaque_t Adesc = {}, Bdesc = {}, Cdesc = {};
cublasLtMatmulPreferenceOpaque_t preference = {};
int returnedResults = 0;
cublasLtMatmulHeuristicResult_t heuristicResult = {};
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT;
// Create operation descriptor; see cublasLtMatmulDescAttributes_t
// for details about defaults; here we just set the transforms for
// A and B.
status = cublasLtMatmulDescInit(&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = CUBLASLT_EPILOGUE_BIAS;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
// Create matrix descriptors. Not setting any extra attributes.
status = cublasLtMatrixLayoutInit(
&Adesc, CUDA_R_16BF, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(
&Bdesc, CUDA_R_16BF, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatrixLayoutInit(&Cdesc, CUDA_R_16BF, m, n, ldc);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// Create preference handle; In general, extra attributes can be
// used here to disable tensor ops or to make sure algo selected
// will work with badly aligned A, B, C. However, for simplicity
// here we assume A,B,C are always well aligned (e.g., directly
// come from cudaMalloc)
status = cublasLtMatmulPreferenceInit(&preference);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
status = cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
// We just need the best available heuristic to try and run matmul.
// There is no guarantee that this will work. For example, if A is
// badly aligned, you can request more (e.g. 32) algos and try to
// run them one by one until something works.
status = cublasLtMatmulAlgoGetHeuristic(
ltHandle, &operationDesc, &Adesc, &Bdesc, &Cdesc, &Cdesc, &preference, 1, &heuristicResult, &returnedResults);
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (returnedResults == 0) {
status = CUBLAS_STATUS_NOT_SUPPORTED;
goto CLEANUP;
}
status = cublasLtMatmul(ltHandle,
&operationDesc,
alpha,
A,
&Adesc,
B,
&Bdesc,
beta,
C,
&Cdesc,
C,
&Cdesc,
//&heuristicResult.algo,
NULL,
workspace,
workspaceSize,
stream);
CLEANUP:
// Descriptors are no longer needed as all GPU work was already
// enqueued.
return status == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
int gemm_bias_gelu_lt(
cublasLtHandle_t ltHandle,
cublasOperation_t transa,
......@@ -332,7 +112,6 @@ int gemm_bias_gelu_lt(
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
int heuristic,
const void* gelu_in,
const void* bias) {
......@@ -363,12 +142,14 @@ int gemm_bias_gelu_lt(
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
}
if (use_bias) {
if (bias != nullptr) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
} else {
epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
......@@ -453,7 +234,6 @@ int gemm_bias_gelu_lt(
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
int heuristic,
const void* gelu_in,
const void* bias) {
......@@ -484,12 +264,14 @@ int gemm_bias_gelu_lt(
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &ldc, sizeof(ldc));
}
if (use_bias) {
if (bias != nullptr) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
}
epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS : CUBLASLT_EPILOGUE_GELU_BIAS;
} else {
epilogue = save_gelu_in ? CUBLASLT_EPILOGUE_GELU_AUX : CUBLASLT_EPILOGUE_GELU;
}
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue));
......@@ -574,7 +356,6 @@ int gemm_bgradb_lt(
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bgrad) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
......@@ -596,7 +377,7 @@ int gemm_bgradb_lt(
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
if (bgrad != nullptr) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
......@@ -684,7 +465,6 @@ int gemm_bgradb_lt(
void *workspace,
size_t workspaceSize,
cudaStream_t stream,
bool use_bias,
const void* bgrad) {
cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
......@@ -706,7 +486,7 @@ int gemm_bgradb_lt(
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(transa));
if (status != CUBLAS_STATUS_SUCCESS) goto CLEANUP;
if (use_bias) {
if (bgrad != nullptr) {
status = cublasLtMatmulDescSetAttribute(&operationDesc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bgrad, sizeof(bgrad));
if (status != CUBLAS_STATUS_SUCCESS) {
goto CLEANUP;
......@@ -1008,132 +788,6 @@ CLEANUP:
#endif
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bias_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_features,
batch_size,
in_features,
&alpha, /* host pointer */
weight,
in_features,
input.data_ptr<T>(),
in_features,
&beta_zero, /* host pointer */
output.data_ptr<T>(),
out_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(bias.data_ptr<T>()));
#endif
if (status != 0){
output.copy_(bias);
status = gemm_bias(
handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_features,
batch_size,
in_features,
&alpha,
weight,
in_features,
input.data_ptr<T>(),
in_features,
&beta_one,
output.data_ptr<T>(),
out_features);
}
return status;
}
template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta = residual ? 1.0 : 0.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
status = gemm_bgradb_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_features,
out_features,
batch_size,
&alpha, /* host pointer */
input,
in_features,
d_output,
out_features,
&beta_zero, /* host pointer */
d_weight,
in_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(d_bias));
#endif
if (status != 0){
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_features,
out_features,
batch_size,
&alpha,
input,
in_features,
d_output,
out_features,
&beta_zero,
d_weight,
in_features);
}
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
in_features,
batch_size,
out_features,
&alpha,
weight,
in_features,
d_output,
out_features,
&beta,
d_input,
in_features);
return status;
}
template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
......@@ -1162,13 +816,10 @@ int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_siz
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(d_bias));
#endif
if (status != 0){
status = gemm_bias(
handle,
CUBLAS_OP_N,
......@@ -1217,7 +868,6 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
lt_workspace,
1 << 22,
stream,
true,
heuristic,
static_cast<const void*>(gelu_in),
static_cast<const void*>(bias));
......@@ -1228,109 +878,46 @@ int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int
}
template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace) {
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace) {
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// Get the stream from cublas handle to reuse for biasReLU kernel.
cudaStream_t stream;
cublasGetStream(handle, &stream);
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta = residual ? 1.0 : 0.0;
int status = 1;
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
//wgrad for first gemm
status = gemm_bgradb_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
hidden_features,
out_features,
batch_size,
&alpha, /* host pointer */
output1,
hidden_features,
d_output2,
out_features,
&beta_zero, /* host pointer */
d_weight2,
hidden_features,
lt_workspace,
1 << 22,
stream,
true,
static_cast<const void*>(d_bias2));
//dgrad for second GEMM
status = gemm_dgelu_bgradb_lt(
(cublasLtHandle_t)handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
hidden_features,
in_features,
batch_size,
out_features,
&alpha, /* host pointer */
weight2,
hidden_features,
d_output2,
weight,
in_features,
d_output,
out_features,
&beta_zero, /* host pointer */
d_output1,
hidden_features,
d_input,
in_features,
lt_workspace,
1 << 22,
stream,
heuristic,
static_cast<const void*>(gelu_in),
static_cast<const void*>(d_bias1));
//wgrad for the first GEMM
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
in_features,
hidden_features,
batch_size,
&alpha,
input,
in_features,
d_output1,
hidden_features,
&beta_zero,
d_weight1,
in_features);
//dgrad for the first GEMM
status = gemm_bias(
handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
in_features,
batch_size,
hidden_features,
&alpha,
weight1,
in_features,
d_output1,
hidden_features,
&beta,
d_input,
in_features);
static_cast<const void*>(d_bias));
#endif
return status;
}
template int linear_bias_forward_cuda<at::Half>(at::Tensor input, at::Half *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_forward_cuda<at::BFloat16>(at::Tensor input, at::BFloat16 *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template int linear_bias_backward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, at::Half *d_input, bool residual, void *lt_workspace) ;
template int linear_bias_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, at::BFloat16 *d_input, bool residual, void *lt_workspace) ;
template int linear_bias_wgrad_cuda<at::Half>(at::Half *input, at::Half *d_output, int in_features, int batch_size, int out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace) ;
template int linear_bias_wgrad_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, int in_features, int batch_size, int out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace) ;
template int linear_gelu_forward_cuda<at::Half>(at::Half *input, at::Half *weight, at::Half *bias, int in_features, int batch_size, int out_features, int heuristic, at::Half *output, at::Half *gelu_in, void *lt_workspace) ;
template int linear_gelu_forward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *weight, at::BFloat16 *bias, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *output, at::BFloat16 *gelu_in, void *lt_workspace) ;
template int linear_gelu_linear_backward_cuda<at::Half>(at::Half *input, at::Half *gelu_in, at::Half *output1, at::Half *weight1, at::Half *weight2, at::Half *d_output1, at::Half *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::Half *d_weight1, at::Half *d_weight2, at::Half *d_bias1, at::Half *d_bias2, at::Half *d_input, bool residual, void *lt_workspace);
template int linear_gelu_linear_backward_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *gelu_in, at::BFloat16 *output1, at::BFloat16 *weight1, at::BFloat16 *weight2, at::BFloat16 *d_output1, at::BFloat16 *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, at::BFloat16 *d_weight1, at::BFloat16 *d_weight2, at::BFloat16 *d_bias1, at::BFloat16 *d_bias2, at::BFloat16 *d_input, bool residual, void *lt_workspace);
template int bias_gelu_linear_dgrad_bgrad_cuda<at::Half>(at::Half *weight, at::Half *d_output, at::Half *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace);
template int bias_gelu_linear_dgrad_bgrad_cuda<at::BFloat16>(at::BFloat16 *weight, at::BFloat16 *d_output, at::BFloat16 *gelu_in, int in_features, int batch_size, int out_features, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace);
\ No newline at end of file
......@@ -10,9 +10,9 @@ from torch.nn.modules.utils import _pair
from einops import rearrange
try:
from flash_attn.ops.fused_dense import FusedDenseTD
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDenseTD = None
FusedDense = None
class PatchEmbed(nn.Module):
......@@ -37,10 +37,10 @@ class PatchEmbed(nn.Module):
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
......
......@@ -30,9 +30,9 @@ from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
try:
from flash_attn.ops.fused_dense import FusedDenseTD
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDenseTD = None
FusedDense = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
......@@ -70,6 +70,8 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
activation=partial(F.gelu, approximate=approximate),
return_residual=return_residual)
else:
if FusedDenseGeluDense is None:
raise ImportError('fused_dense is not installed')
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
......@@ -168,9 +170,9 @@ class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
......@@ -188,12 +190,12 @@ class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
self.transform_act_fn = nn.GELU(approximate=approximate)
......@@ -215,9 +217,9 @@ class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.transform = BertPredictionHeadTransform(config)
......
......@@ -61,6 +61,8 @@ def create_mlp_cls(config, layer_idx=None):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_dense_gelu_dense:
if FusedDenseGeluDense is None:
raise ImportError('fused_dense is not installed')
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
elif fused_dense_sqrelu_dense:
......
......@@ -21,9 +21,9 @@ except ImportError:
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
try:
from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDenseTD, FusedDenseResidual = None, None
FusedDense = None
try:
from flash_attn.layers.rotary import RotaryEmbedding
......@@ -270,7 +270,7 @@ class CrossAttention(nn.Module):
class LinearResidual(nn.Linear):
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
......@@ -311,10 +311,11 @@ class MHA(nn.Module):
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
linear_resid_cls = (LinearResidual if not fused_bias_fc
else partial(FusedDense, return_residual=True))
if not self.cross_attn:
if not self.return_residual:
self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
......
......@@ -5,11 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
try:
from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
from flash_attn.ops.fused_dense import FusedDenseGeluDense
except ImportError:
fused_dense_gelu_dense_function_td = None
fused_dense_res_gelu_dense_function_td = None
FusedDenseGeluDense = None
class Mlp(nn.Module):
......@@ -30,43 +28,3 @@ class Mlp(nn.Module):
y = self.activation(y)
y = self.fc2(y)
return y if not self.return_residual else (y, x)
class FusedDenseGeluDense(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
assert checkpoint_lvl in [0, 1, 2]
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert bias == True, "DenseGeluDense module without bias is currently not supported"
assert (fused_dense_gelu_dense_function_td is not None
and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
def forward(self, x):
assert x.is_cuda
fn = (fused_dense_gelu_dense_function_td if not self.return_residual
else fused_dense_res_gelu_dense_function_td)
return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
self.checkpoint_lvl, self.heuristic)
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
# We make it work with pytorch amp and with bfloat16.
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
# import fused_dense_cuda # from apex
......@@ -11,126 +13,84 @@ import fused_dense_lib as fused_dense_cuda
from flash_attn.ops.gelu_activation import gelu_bwd
# implements fused GEMM+bias in forward pass using mlp_cuda from apex
class FusedDenseFuncTD(torch.autograd.Function):
class FusedDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias):
def forward(ctx, x, weight, bias, return_residual=False):
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
x, weight = [a.to(dtype=dtype) for a in [x, weight]]
bias = bias.to(dtype=dtype) if bias is not None else None
ctx.return_residual = return_residual
x = x.contiguous()
weight = weight.contiguous()
bias = bias.contiguous()
ctx.save_for_backward(x, weight)
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
return output.reshape(*batch_shape, output.shape[-1])
output = F.linear(x, weight, bias)
return output if not return_residual else (output, x)
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
if ctx.return_residual:
grad_input, = args
grad_input = grad_input.contiguous()
x, weight = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
if ctx.needs_input_grad[0]:
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1])
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
if ctx.needs_input_grad[1]:
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
)
else:
grad_weight = None
grad_bias = grad_output if ctx.needs_input_grad[2] else None
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_output, weight.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
grad_input = grad_input.reshape_as(x)
else:
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
x.reshape(batch_dim, n), grad_output.reshape(batch_dim, grad_output.shape[-1])
)
grad_input = None
# print((grad_bias - grad_output.view(-1, grad_output.shape[-1]).sum(dim=0)).abs().max())
return grad_input, grad_weight, grad_bias
# grad_input, grad_weight = None, None
# grad_output_reshaped = grad_output.reshape(batch_dim, grad_output.shape[-1])
# if ctx.needs_input_grad[0]:
# grad_input = (grad_output_reshaped @ weight.conj()).reshape(*batch_shape, n)
# if ctx.needs_input_grad[1]:
# grad_weight = grad_output_reshaped.t() @ x.conj().reshape(batch_dim, n)
# # We don't need to compute grad_bias explicitly, when we return grad_out Pytorch
# # will sum over the batch dimension to get grad_bias.
# return grad_input, grad_weight, grad_output
return grad_input, grad_weight, grad_bias, None
fused_dense_function_td = FusedDenseFuncTD.apply
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
return_residual: bool = False):
batch_dim = x.shape[:-1].numel()
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
and dtype_eligible):
return FusedDenseFunc.apply(x, weight, bias, return_residual)
else:
out = F.linear(x, weight, bias)
return out if not return_residual else (out, x)
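fused_dense_func routes to the custom autograd function only when the fused kernel applies (CUDA tensors, fp16/bf16 or fp32 under autocast, batch dimension at most 64K) and otherwise falls back to F.linear. A small hedged example of the two paths; the import of fused_dense_func from flash_attn.ops.fused_dense is assumed, and the sizes are arbitrary.

import torch
from flash_attn.ops.fused_dense import fused_dense_func

w = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(1024, device='cuda', dtype=torch.float16)

x16 = torch.randn(4, 512, 1024, device='cuda', dtype=torch.float16)
y16 = fused_dense_func(x16, w, b)                      # takes the FusedDenseFunc path

x32 = torch.randn(4, 512, 1024, device='cuda')         # fp32 and no autocast
y32 = fused_dense_func(x32, w.float(), b.float())      # falls back to F.linear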
class FusedDenseTD(nn.Linear):
class FusedDense(nn.Linear):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
return_residual: bool = False, device=None, dtype=None) -> None:
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
self.return_residual = return_residual
def forward(self, x):
if x.is_cuda and self.bias is not None:
return fused_dense_function_td(x, self.weight, self.bias)
else:
return F.linear(x, self.weight, self.bias)
return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
class FusedDenseResidualFunc(torch.autograd.Function):
class FusedDenseGeluDenseFunc(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight, bias):
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight, bias = [a.to(dtype=dtype) for a in [x, weight, bias]]
x = x.contiguous()
x = x.contiguous()
weight = weight.contiguous()
bias = bias.contiguous()
ctx.save_for_backward(x, weight)
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
output = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight, bias)
return output.reshape(*batch_shape, output.shape[-1]), x
@staticmethod
@custom_bwd
def backward(ctx, grad_output, grad_input):
grad_output = grad_output.contiguous()
grad_input = grad_input.contiguous()
x, weight = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
grad_input, grad_weight, grad_bias = fused_dense_cuda.linear_bias_residual_backward(
x.reshape(batch_dim, n), weight, grad_output.reshape(batch_dim, grad_output.shape[-1]),
grad_input.reshape(batch_dim, n)
)
return grad_input.reshape_as(x), grad_weight, grad_bias
fused_dense_residual_function = FusedDenseResidualFunc.apply
class FusedDenseResidual(nn.Linear):
"""Similar to FusedDense, but we return both the output and the input.
This is so that in the backward pass, we can combine the input gradient from the residual branch
with the input gradient from the matrix multiply, without having to do a separate addition.
"""
def forward(self, x):
if x.is_cuda and self.bias is not None:
return fused_dense_residual_function(x, self.weight, self.bias)
else:
return F.linear(x, self.weight, self.bias), x
class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, x, weight1, bias1, weight2, bias2, checkpoint_lvl=0, heuristic=0):
def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
checkpoint_lvl=0, heuristic=0):
"""checkpoint_lvl:
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
......@@ -139,49 +99,53 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
assert -1 <= heuristic <= 4
if torch.is_autocast_enabled():
dtype = torch.get_autocast_gpu_dtype()
x, weight1, bias1, weight2, bias2 = [a.to(dtype=dtype)
for a in [x, weight1, bias1, weight2, bias2]]
x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
if not save_gelu_in:
checkpoint_lvl = 2
assert checkpoint_lvl in [0, 1, 2]
ctx.return_residual = return_residual
x = x.contiguous()
weight1 = weight1.contiguous()
bias1 = bias1.contiguous()
bias1 = bias1.contiguous() if bias1 is not None else None
weight2 = weight2.contiguous()
bias2 = bias2.contiguous()
bias2 = bias2.contiguous() if bias2 is not None else None
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
# output1, output2, gelu_in = fused_dense_cuda.linear_gelu_linear_forward(
# x.reshape(batch_dim, n), weight1, bias1, weight2, bias2
# )
if heuristic == -1:
gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
gelu_in = F.linear(x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh')
# gelu_in = F.linear(x.reshape(batch_dim, n), weight1) # This is before adding bias1
# with torch.jit.fuser('fuser2'):
# output1 = bias_gelu(gelu_in, bias1)
else:
save_gelu_in = checkpoint_lvl != 2
output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
bias1, save_gelu_in, heuristic)
if save_gelu_in:
gelu_in = rest[0]
output2 = fused_dense_cuda.linear_bias_forward(output1, weight2, bias2)
output2 = F.linear(output1, weight2, bias2)
ctx.checkpoint_lvl = checkpoint_lvl
ctx.heuristic = heuristic
if checkpoint_lvl == 0:
ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in, output1)
ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
elif checkpoint_lvl == 1:
ctx.save_for_backward(x, weight1, bias1, weight2, gelu_in)
ctx.save_for_backward(x, weight1, weight2, gelu_in)
elif checkpoint_lvl == 2:
ctx.save_for_backward(x, weight1, bias1, weight2)
return output2.reshape(*batch_shape, output2.shape[-1])
ctx.save_for_backward(x, weight1, weight2, bias1)
output2 = output2.reshape(*batch_shape, output2.shape[-1])
return output2 if not return_residual else (output2, x)
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
def backward(ctx, grad_output, *args):
grad_output = grad_output.contiguous()
checkpoint_lvl = ctx.checkpoint_lvl
x, weight1, bias1, weight2, *rest = ctx.saved_tensors
if ctx.return_residual:
grad_input, = args
grad_input = grad_input.contiguous()
x, weight1, weight2, *rest = ctx.saved_tensors
batch_shape, n = x.shape[:-1], x.shape[-1]
batch_dim = batch_shape.numel()
if checkpoint_lvl == 0:
......@@ -190,55 +154,88 @@ class FusedDenseGeluDenseFuncTD(torch.autograd.Function):
gelu_in, = rest
output1 = F.gelu(gelu_in, approximate='tanh')
elif checkpoint_lvl == 2:
# bias1, = rest
bias1, = rest
if ctx.heuristic == -1:
gelu_in = fused_dense_cuda.linear_bias_forward(x.reshape(batch_dim, n), weight1, bias1)
gelu_in = F.linear(x, weight1, bias1)
output1 = F.gelu(gelu_in, approximate='tanh')
else:
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n),
weight1, bias1, True, ctx.heuristic)
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic
)
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
output1 = output1.reshape(batch_dim, output1.shape[-1])
gelu_in = gelu_in.reshape(batch_dim, gelu_in.shape[-1])
if ctx.needs_input_grad[3]:
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(
output1, grad_output, ctx.needs_input_grad[4]
)
else:
grad_weight2 = None
grad_bias2 = grad_output if ctx.needs_input_grad[4] else None
if ctx.heuristic == -1:
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
# grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
grad_output1 = grad_output @ weight2
grad_output1 = F.linear(grad_output, weight2.t())
with torch.jit.fuser('fuser2'):
grad_gelu = gelu_bwd(grad_output1, gelu_in)
grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
x.reshape(batch_dim, n), weight1, grad_gelu
)
# with torch.jit.fuser('fuser2'):
# grad_gelu, grad_bias1 = bias_gelu_back(grad_output1, gelu_in, bias1)
# grad_input = grad_gelu @ weight1
# grad_weight1 = grad_gelu.reshape(batch_dim, -1).T @ x.reshape(batch_dim, n)
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
# x.reshape(batch_dim, n), weight1, grad_gelu
# )
if ctx.needs_input_grad[1]:
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2]
)
else:
grad_weight1 = None
grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
else:
grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_gelu_linear_backward(
x.reshape(batch_dim, n), gelu_in, output1, weight1, weight2,
grad_output.reshape(batch_dim, grad_output.shape[-1]),
ctx.heuristic
# The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
# just compute gelu grad
grad_gelu, grad_bias1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(
weight2, grad_output, gelu_in, ctx.heuristic
)
# grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
# # grad_output1, grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_backward(output1, weight2, grad_output)
# grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad(output1, grad_output)
# grad_gelu = matmul_dgelu(grad_output, weight2, gelu_in)
# grad_input, grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_backward(
# x.reshape(batch_dim, n), weight1, grad_gelu
# )
return grad_input.reshape_as(x), grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None
fused_dense_gelu_dense_function_td = FusedDenseGeluDenseFuncTD.apply
if not ctx.needs_input_grad[2]:
grad_bias1 = None
if ctx.needs_input_grad[1]:
grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
else:
grad_weight1 = None
if ctx.needs_input_grad[0]:
if not ctx.return_residual:
grad_input = F.linear(grad_gelu, weight1.t())
else:
grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
grad_input = grad_input.reshape_as(x)
else:
grad_input = None
return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None

def fused_dense_gelu_dense_func(
    x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
    bias2: Optional[Tensor] = None,
    save_gelu_in: bool = True, return_residual: bool = False,
    checkpoint_lvl: int = 0, heuristic: int = 0
):
    batch_dim = x.shape[:-1].numel()
    dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
                      or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
    if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda)
            and (bias2 is None or bias2.is_cuda) and batch_dim <= 64 * 1024
            and dtype_eligible):
        return FusedDenseGeluDenseFunc.apply(
            x, weight1, bias1, weight2, bias2,
            save_gelu_in, return_residual, checkpoint_lvl, heuristic
        )
    else:
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        output2 = F.linear(output1, weight2, bias2)
        return output2 if not return_residual else (output2, x)
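
# Illustrative usage of the functional interface (a sketch, not part of this module's API;
# shapes follow the nn.Linear convention weight1: (hidden, in), weight2: (out, hidden)):
#
#     x = torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16)
#     w1 = torch.randn(4096, 1024, device='cuda', dtype=torch.float16, requires_grad=True)
#     w2 = torch.randn(1024, 4096, device='cuda', dtype=torch.float16, requires_grad=True)
#     out = fused_dense_gelu_dense_func(x, w1, w2, heuristic=0)
#
# On CPU tensors, or fp32 inputs without autocast, the same call takes the unfused
# F.linear / F.gelu fallback above.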

class FusedDenseGeluDense(nn.Module):

    def __init__(self, in_features, hidden_features, out_features=None, bias1=True,
                 bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
                 device=None, dtype=None):
        """
        checkpoint_lvl (increasing lvl means slower but more memory saving):
            0: no recomputation in the bwd
            1: recompute gelu_out in the bwd
            2: recompute gelu_in and gelu_out in the bwd
        heuristic:
            -1: don't fuse gemm + gelu (separate kernel)
            0..4: use this heuristic for the algo section in the fused gemm + gelu
            For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
            For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
        return_residual: whether to return the input x along with the output. This is for
            performance reasons: for a post-norm architecture, returning the input allows us
            to fuse the backward of nn.Linear with the residual connection.
        """
        assert checkpoint_lvl in [0, 1, 2]
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if out_features is None:
            out_features = in_features
        self.return_residual = return_residual
        self.checkpoint_lvl = checkpoint_lvl
        self.heuristic = heuristic
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)

    def forward(self, x):
        return fused_dense_gelu_dense_func(
            x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
            save_gelu_in=self.training, return_residual=self.return_residual,
            checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic
        )
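
# Illustrative usage (a sketch, assuming the fused_dense_cuda extension is built and a CUDA
# device is available); the module is a drop-in replacement for a Linear -> GELU(tanh) -> Linear
# block:
#
#     mlp = FusedDenseGeluDense(1024, 4096, 1024, heuristic=0, device='cuda', dtype=torch.float16)
#     y = mlp(torch.randn(8, 512, 1024, device='cuda', dtype=torch.float16))
#     # with return_residual=True, forward returns (y, x) so the caller can fuse the residual add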
......@@ -6,29 +6,44 @@ import pytest
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
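
# To run just these tests (a sketch; the file path is assumed from the repo layout and may differ):
#     pytest -q tests/ops/test_fused_dense.py -k "fused_linear_bias or fused_dense_gelu_dense"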
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('return_residual', [False, True])
@pytest.mark.parametrize('has_bias', [True, False])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
    model = FusedDense(in_features, out_features, bias=has_bias, return_residual=return_residual,
                       device=device, dtype=dtype)
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        if has_bias:
            model.bias.copy_(model_pt.bias)
    out_pt = model_pt(x_pt)
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        x_copy = (x_copy[..., :out_features] if out_features < in_features
                  else F.pad(x_copy, (0, out_features - in_features)))
        x_pt_copy = (x_pt[..., :out_features] if out_features < in_features
                     else F.pad(x_pt, (0, out_features - in_features)))
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt_copy)
        out = out + F.gelu(x_copy)
    # with torch.no_grad():
    #     out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
    if has_bias:
        assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [0, -1])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('return_residual', [False, True])
@pytest.mark.parametrize('has_bias2', [True, False])
@pytest.mark.parametrize('has_bias1', [True, False])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
                                checkpoint_lvl, heuristic, dtype):
    device = 'cuda'
    rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 512
    x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
                       requires_grad=True)
    x = x_pt.detach().clone().requires_grad_()
    model_pt_fc1 = torch.nn.Linear(in_features, out_features, bias=has_bias1, device=device,
                                   dtype=dtype)
    model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
                                   dtype=dtype)
    model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
                                bias2=has_bias2, return_residual=return_residual,
                                checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
                                device=device, dtype=dtype)
    with torch.no_grad():
        model.fc1.weight.copy_(model_pt_fc1.weight)
        if has_bias1:
            model.fc1.bias.copy_(model_pt_fc1.bias)
        model.fc2.weight.copy_(model_pt_fc2.weight)
        if has_bias2:
            model.fc2.bias.copy_(model_pt_fc2.bias)
    out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
    if not return_residual:
        out = model(x)
    else:
        out, x_copy = model(x)
        # Just add some random function of the residual
        out_pt = out_pt + F.gelu(x_pt)
        out = out + F.gelu(x_copy)
    assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(out) / 32
    out_pt.backward(g)
    out.backward(g)
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
    if has_bias1:
        assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
    assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
    if has_bias2:
        assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
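
# With the parametrizations above, this test expands to
# 2 (dtype) * 2 (heuristic) * 3 (checkpoint_lvl) * 2 (return_residual) * 2 (has_bias1)
# * 2 (has_bias2) * 2 (out_features) * 2 (in_features) = 384 configurations.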