Commit 27f8f890 authored by Tri Dao

[FusedDense] Allocate lt_workspace on input device

parent 48bc6eac
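The change is the same in all three extension entry points: instead of each cuBLASLt GEMM helper allocating its own workspace on at::cuda::current_device(), the caller now builds lt_workspace from the input tensor's options, so the buffer lives on the input's device, and passes only the raw pointer and byte count down. A minimal sketch of that pattern, with hypothetical names (run_lt_gemm, fused_linear) standing in for the real wrappers:

// Hedged sketch, not the FusedDense sources: hypothetical helpers that
// illustrate the caller-side workspace allocation this commit introduces.
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

// Stand-in for the cuBLASLt wrappers, which after this commit receive a
// caller-owned workspace instead of allocating one internally.
static int run_lt_gemm(void* lt_workspace, size_t workspaceSize) {
    // A real wrapper forwards these two values to cublasLtMatmul.
    (void)lt_workspace; (void)workspaceSize;
    return 0;
}

static void fused_linear(at::Tensor input) {
    // Same size heuristic as the diff: 32 MiB on Hopper (SM 9.x), 4 MiB otherwise.
    size_t workspaceSize =
        1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
    // Allocating from the input's options ties the workspace to the input's
    // device (the point of the commit); the removed code allocated on
    // at::cuda::current_device() inside the GEMM helpers instead.
    auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)},
                                  input.options().dtype(torch::kUInt8));
    // Only the raw pointer and size cross the C-style interface; the tensor
    // keeps the allocation alive for the duration of the call.
    TORCH_CHECK(run_lt_gemm(lt_workspace.data_ptr(), workspaceSize) == 0,
                "gemm failed");
}

The diff below applies exactly this to linear_bias_wgrad, linear_act_forward, and bias_act_linear_dgrad_bgrad, and threads the two extra parameters through the *_cuda wrappers and the cuBLASLt GEMM helpers.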
@@ -2,6 +2,7 @@
 // We make it work for bfloat16
 #include <torch/extension.h>
 #include <torch/torch.h>
+#include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
 #include <vector>
@@ -28,13 +29,13 @@
 }
 template <typename T>
-int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize);
 template <typename T>
-int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
 template <typename T>
-int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize);
 std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
@@ -66,6 +67,11 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
 d_bias = at::empty({out_features}, opts);
 #endif
 }
+// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
 auto result = linear_bias_wgrad_cuda<scalar_t>(
@@ -75,7 +81,9 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
 batch_size,
 out_features,
 d_weight.data_ptr<scalar_t>(),
-has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
+has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
+(void*) (lt_workspace.data_ptr()),
+workspaceSize);
 TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
 });
@@ -117,6 +125,11 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
 // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
 if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
 is_gelu ? opts : opts.dtype(torch::kUInt8)); }
+// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
 auto result = linear_act_forward_cuda<scalar_t>(
@@ -129,7 +142,9 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
 is_gelu,
 heuristic,
 output.data_ptr<scalar_t>(),
-save_pre_act ? pre_act.data_ptr() : nullptr);
+save_pre_act ? pre_act.data_ptr() : nullptr,
+(void*) (lt_workspace.data_ptr()),
+workspaceSize);
 TORCH_CHECK(result == 0, "linear_act_forward failed.");
 });
@@ -168,6 +183,11 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
 auto opts = weight.options();
 auto d_bias = at::empty({in_features}, opts);
 auto d_input = at::empty({batch_size, in_features}, opts);
+// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind setting this to 1M.
+// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
+// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
+auto lt_workspace = at::empty({static_cast<int64_t>(workspaceSize)}, opts.dtype(torch::kUInt8));
 DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
 auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
@@ -180,7 +200,9 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
 is_gelu,
 heuristic,
 d_input.data_ptr<scalar_t>(),
-d_bias.data_ptr<scalar_t>());
+d_bias.data_ptr<scalar_t>(),
+(void*) (lt_workspace.data_ptr()),
+workspaceSize);
 TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
 });
@@ -110,7 +110,9 @@ int gemm_bias_act_lt(
 int64_t ldc,
 void* pre_act,
 bool is_gelu,
-int heuristic
+int heuristic,
+void *lt_workspace,
+size_t workspaceSize
 ) {
 static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
 "gemm_bias_act_lt only supports fp16 and bf16");
@@ -120,14 +122,6 @@ int gemm_bias_act_lt(
 cublasLtHandle_t ltHandle =
 reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-// setting this to 1M.
-// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-// https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
-size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-void* workspace = at::empty(
-{static_cast<int64_t>(workspaceSize)},
-at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
@@ -228,7 +222,7 @@ int gemm_bias_act_lt(
 // TD [2022-04-29] Somehow algo 0 and 2 are a lot slower than other algos
 &heuristicResult[heuristic].algo,
 // NULL,
-workspace,
+lt_workspace,
 workspaceSize,
 at::cuda::getCurrentCUDAStream());
@@ -254,7 +248,9 @@ template int gemm_bias_act_lt(
 int64_t ldc,
 void* pre_act,
 bool is_gelu,
-int heuristic);
+int heuristic,
+void *lt_workspace,
+size_t workspaceSize);
 template int gemm_bias_act_lt(
 cublasOperation_t transa,
@@ -272,7 +268,9 @@ template int gemm_bias_act_lt(
 int64_t ldc,
 void* pre_act,
 bool is_gelu,
-int heuristic);
+int heuristic,
+void *lt_workspace,
+size_t workspaceSize);
 template <typename Dtype>
 int gemm_bgradb_lt(
@@ -288,7 +286,9 @@ int gemm_bgradb_lt(
 int64_t ldb,
 Dtype* C,
 int64_t ldc,
-Dtype* bgrad) {
+Dtype* bgrad,
+void *lt_workspace,
+size_t workspaceSize) {
 static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
 "gemm_bgradb_lt only supports fp16 and bf16");
 float beta = 0.0;
@@ -296,13 +296,6 @@ int gemm_bgradb_lt(
 cublasLtHandle_t ltHandle =
 reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-// setting this to 1M.
-// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-void* workspace = at::empty(
-{static_cast<int64_t>(workspaceSize)},
-at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
@@ -384,7 +377,7 @@ int gemm_bgradb_lt(
 &Cdesc,
 //&heuristicResult.algo,
 NULL,
-workspace,
+lt_workspace,
 workspaceSize,
 at::cuda::getCurrentCUDAStream());
@@ -408,7 +401,9 @@ template int gemm_bgradb_lt(
 int64_t ldb,
 at::Half* C,
 int64_t ldc,
-at::Half* bgrad);
+at::Half* bgrad,
+void *lt_workspace,
+size_t workspaceSize);
 template int gemm_bgradb_lt(
 cublasOperation_t transa,
@@ -423,7 +418,9 @@ template int gemm_bgradb_lt(
 int64_t ldb,
 at::BFloat16* C,
 int64_t ldc,
-at::BFloat16* bgrad);
+at::BFloat16* bgrad,
+void *lt_workspace,
+size_t workspaceSize);
 template <typename Dtype>
 int gemm_dact_bgradb_lt(
@@ -442,7 +439,9 @@ int gemm_dact_bgradb_lt(
 int64_t ldc,
 Dtype* bgrad,
 bool is_gelu,
-int heuristic) {
+int heuristic,
+void *lt_workspace,
+size_t workspaceSize) {
 static_assert(std::is_same<Dtype, at::Half>::value || std::is_same<Dtype, at::BFloat16>::value,
 "gemm_dact_bgradb_lt only supports fp16 and bf16");
 float beta = 0.0;
@@ -450,13 +449,6 @@ int gemm_dact_bgradb_lt(
 cublasLtHandle_t ltHandle =
 reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
-// See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-// setting this to 1M.
-// However, Apex sets it to 4M and TransformerEngine sets to 32M for Hopper and 4M for other GPUs
-size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
-void* workspace = at::empty(
-{static_cast<int64_t>(workspaceSize)},
-at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
 cublasStatus_t status = CUBLAS_STATUS_SUCCESS;
@@ -542,7 +534,7 @@ int gemm_dact_bgradb_lt(
 //&heuristicResult.algo,
 &heuristicResult[heuristic].algo,
 // NULL,
-workspace,
+lt_workspace,
 workspaceSize,
 at::cuda::getCurrentCUDAStream());
@@ -568,7 +560,9 @@ template int gemm_dact_bgradb_lt(
 int64_t ldc,
 at::Half* bgrad,
 bool is_gelu,
-int heuristic);
+int heuristic,
+void *lt_workspace,
+size_t workspaceSize);
 template int gemm_dact_bgradb_lt(
 cublasOperation_t transa,
@@ -586,12 +580,14 @@ template int gemm_dact_bgradb_lt(
 int64_t ldc,
 at::BFloat16* bgrad,
 bool is_gelu,
-int heuristic);
+int heuristic,
+void *lt_workspace,
+size_t workspaceSize);
 #endif
 template <typename T>
-int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias) {
+int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias, void *lt_workspace, size_t workspaceSize) {
 const float alpha = 1.0;
 const float beta_zero = 0.0;
 int status = 1;
@@ -610,7 +606,9 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
 out_features,
 d_weight,
 in_features,
-d_bias);
+d_bias,
+lt_workspace,
+workspaceSize);
 #endif
 if (status != 0){
@@ -652,7 +650,7 @@ int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_feature
 }
 template <typename T>
-int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act) {
+int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act, void *lt_workspace, size_t workspaceSize) {
 int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
 status = gemm_bias_act_lt(
@@ -671,7 +669,9 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
 out_features,
 pre_act,
 is_gelu,
-heuristic);
+heuristic,
+lt_workspace,
+workspaceSize);
 return status;
 #else
 return 1;
@@ -679,7 +679,7 @@ int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int6
 }
 template <typename T>
-int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias) {
+int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias, void *lt_workspace, size_t workspaceSize) {
 const float alpha = 1.0;
 int status = 1;
 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11600
@@ -699,17 +699,19 @@ int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const v
 in_features,
 d_bias,
 is_gelu,
-heuristic);
+heuristic,
+lt_workspace,
+workspaceSize);
 #endif
 return status;
 }
-template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias);
+template int linear_bias_wgrad_cuda<at::Half>(const at::Half *input, const at::Half *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::Half *d_weight, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
-template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias);
+template int linear_bias_wgrad_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, at::BFloat16 *d_weight, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
-template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act);
+template int linear_act_forward_cuda<at::Half>(const at::Half *input, const at::Half *weight, const at::Half *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
-template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act);
+template int linear_act_forward_cuda<at::BFloat16>(const at::BFloat16 *input, const at::BFloat16 *weight, const at::BFloat16 *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *output, void *pre_act, void *lt_workspace, size_t workspaceSize);
-template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias);
+template int bias_act_linear_dgrad_bgrad_cuda<at::Half>(const at::Half *weight, const at::Half *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::Half *d_input, at::Half *d_bias, void *lt_workspace, size_t workspaceSize);
-template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias);
+template int bias_act_linear_dgrad_bgrad_cuda<at::BFloat16>(const at::BFloat16 *weight, const at::BFloat16 *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, at::BFloat16 *d_input, at::BFloat16 *d_bias, void *lt_workspace, size_t workspaceSize);
\ No newline at end of file
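For readers unfamiliar with the cuBLASLt side, the two forwarded values end up as the workspace arguments of cublasLtMatmul (the lt_workspace / workspaceSize lines in the hunks above). Below is a generic, self-contained sketch of that call shape; it is not FlashAttention code, and the dimensions, data types, and minimal error handling are made up for illustration:

// Hedged sketch: a plain fp16 GEMM through cuBLASLt with a caller-provided
// workspace, showing where the pointer and byte count plug in.
#include <cublasLt.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int m = 64, n = 64, k = 64;
    const size_t workspaceSize = 4 * 1024 * 1024;  // e.g. the 4 MiB pre-Hopper value

    __half *A, *B, *D;
    void *workspace;
    cudaMalloc(&A, sizeof(__half) * m * k);   // inputs left uninitialized; this
    cudaMalloc(&B, sizeof(__half) * k * n);   // only demonstrates the call shape
    cudaMalloc(&D, sizeof(__half) * m * n);
    cudaMalloc(&workspace, workspaceSize);    // caller-owned, like lt_workspace above

    cublasLtHandle_t lt;
    cublasLtCreate(&lt);

    cublasLtMatmulDesc_t op;
    cublasLtMatmulDescCreate(&op, CUBLAS_COMPUTE_32F, CUDA_R_32F);

    cublasLtMatrixLayout_t Adesc, Bdesc, Ddesc;       // column-major layouts
    cublasLtMatrixLayoutCreate(&Adesc, CUDA_R_16F, m, k, m);
    cublasLtMatrixLayoutCreate(&Bdesc, CUDA_R_16F, k, n, k);
    cublasLtMatrixLayoutCreate(&Ddesc, CUDA_R_16F, m, n, m);

    // Tell the heuristic how much workspace the caller is willing to provide.
    cublasLtMatmulPreference_t pref;
    cublasLtMatmulPreferenceCreate(&pref);
    cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                         &workspaceSize, sizeof(workspaceSize));

    cublasLtMatmulHeuristicResult_t heur;
    int returned = 0;
    cublasLtMatmulAlgoGetHeuristic(lt, op, Adesc, Bdesc, Ddesc, Ddesc, pref, 1, &heur, &returned);

    const float alpha = 1.0f, beta = 0.0f;
    cublasStatus_t status = cublasLtMatmul(
        lt, op, &alpha, A, Adesc, B, Bdesc, &beta, D, Ddesc, D, Ddesc,
        returned > 0 ? &heur.algo : nullptr,
        workspace, workspaceSize,   // <- the two values this commit now threads through
        /*stream=*/0);
    printf("cublasLtMatmul status: %d\n", (int)status);

    cublasLtMatmulPreferenceDestroy(pref);
    cublasLtMatrixLayoutDestroy(Adesc);
    cublasLtMatrixLayoutDestroy(Bdesc);
    cublasLtMatrixLayoutDestroy(Ddesc);
    cublasLtMatmulDescDestroy(op);
    cublasLtDestroy(lt);
    cudaFree(A); cudaFree(B); cudaFree(D); cudaFree(workspace);
    return 0;
}

Capping CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES at the same value keeps the heuristic from selecting an algorithm that needs more workspace than the caller actually allocated.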