Commit e68ebbe8 authored by Tri Dao

Simplify FusedDense

parent 1bc6e5b0
@@ -6,6 +6,8 @@
#include <stdio.h>
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...) \
@@ -24,14 +26,6 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
template <typename T>
int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
template <typename T>
int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, bool residual, void *lt_workspace);
template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);
@@ -39,103 +33,34 @@ template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);
template <typename T>
int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, int heuristic, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, bool residual, void *lt_workspace);
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);
at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto out = at::empty({batch_size, out_features}, at::dtype(input.dtype()).device(input.device()));
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, at::dtype(input.dtype()).device(input.device()));
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_forward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_forward_cuda<scalar_t>(
input,
w_ptr,
bias,
in_features,
batch_size,
out_features,
out,
//out.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_forward failed.")
});
return {out};
}
std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
#endif
auto d_input = at::empty({batch_size, in_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/false,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_backward failed.")
});
return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int batch_size = input.size(0);
int in_features = input.size(1);
int out_features = d_output.size(1);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == d_output.dtype());
TORCH_CHECK(input.is_cuda());
TORCH_CHECK(d_output.is_cuda());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(d_output.is_contiguous());
CHECK_SHAPE(input, batch_size, in_features);
CHECK_SHAPE(d_output, batch_size, out_features);
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
at::Tensor d_bias;
if (has_d_bias) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
d_bias = at::empty({out_features}, opts);
#endif
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
}
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
@@ -147,93 +72,59 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output)
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.")
TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
});
return {d_weight, d_bias};
}
std::vector<at::Tensor> linear_bias_residual_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output, at::Tensor d_input) {
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
c10::optional<at::Tensor> bias_,
auto batch_size = input.size(0);
auto in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight = at::empty({out_features, in_features}, opts);
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
auto d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
auto d_bias = at::empty({out_features}, opts);
#endif
CHECK_SHAPE(d_input, batch_size, in_features);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
auto result = linear_bias_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
d_output.data_ptr<scalar_t>(),
in_features,
batch_size,
out_features,
d_weight.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/true,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_bias_residual_backward failed.")
});
return {d_input, d_weight, d_bias};
}
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight, at::Tensor bias,
bool save_gelu_in, int heuristic) {
auto batch_size = input.size(0);
int batch_size = input.size(0);
auto in_features = input.size(1);
int in_features = input.size(1);
int out_features = weight.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
TORCH_CHECK(input.dtype() == weight.dtype());
TORCH_CHECK(input.is_cuda());
TORCH_CHECK(weight.is_cuda());
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(weight.is_contiguous());
CHECK_SHAPE(input, batch_size, in_features);
CHECK_SHAPE(weight, out_features, in_features);
if (bias_.has_value()) {
auto bias = bias_.value();
TORCH_CHECK(bias.dtype() == input.dtype());
TORCH_CHECK(bias.is_cuda());
TORCH_CHECK(bias.is_contiguous());
CHECK_SHAPE(bias, out_features);
}
// create output/workspace tensor
auto opts = input.options();
auto output = at::empty({batch_size, out_features}, opts);
at::Tensor gelu_in;
if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
scalar_t* w_ptr = weight.data_ptr<scalar_t>();
scalar_t* b_ptr = bias.data_ptr<scalar_t>();
auto result = linear_gelu_forward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
w_ptr,
weight.data_ptr<scalar_t>(),
b_ptr,
bias_.has_value()? bias_.value().data_ptr<scalar_t>() : nullptr,
in_features,
batch_size,
out_features,
heuristic,
output.data_ptr<scalar_t>(),
save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
// reserved_space.data_ptr<scalar_t>(),
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_gelu_forward failed.")
TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
});
std::vector<at::Tensor> result = {output};
@@ -241,116 +132,54 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
return result;
}
std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, int heuristic) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
) {
int batch_size = d_output.size(0);
int out_features = d_output.size(1);
int in_features = weight.size(1);
TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
TORCH_CHECK(weight.dtype() == d_output.dtype());
TORCH_CHECK(weight.dtype() == gelu_in.dtype());
TORCH_CHECK(weight.is_cuda());
TORCH_CHECK(d_output.is_cuda());
TORCH_CHECK(gelu_in.is_cuda());
TORCH_CHECK(weight.is_contiguous());
TORCH_CHECK(d_output.is_contiguous());
TORCH_CHECK(gelu_in.is_contiguous());
CHECK_SHAPE(weight, out_features, in_features);
CHECK_SHAPE(d_output, batch_size, out_features);
CHECK_SHAPE(gelu_in, batch_size, in_features);
// create output/workspace tensor
auto opts = input.options();
auto opts = weight.options();
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
auto d_bias1 = at::empty({hidden_features}, opts);
auto d_bias2 = at::empty({out_features}, opts);
auto d_bias = at::empty({in_features}, opts);
auto d_input = at::empty({batch_size, in_features}, opts);
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
input.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
output1.data_ptr<scalar_t>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
d_output1.data_ptr<scalar_t>(),
d_output2.data_ptr<scalar_t>(),
in_features,
batch_size,
hidden_features,
out_features,
heuristic,
d_weight1.data_ptr<scalar_t>(),
d_weight2.data_ptr<scalar_t>(),
d_bias1.data_ptr<scalar_t>(),
d_bias2.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
/*residual=*/false,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_gelu_linear_backward failed.")
});
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
}
std::vector<at::Tensor> linear_residual_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2, at::Tensor d_input, int heuristic) {
auto batch_size = input.size(0);
auto in_features = input.size(1);
int hidden_features = weight1.size(0);
int out_features = weight2.size(0);
//auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
// create output/workspace tensor
auto opts = input.options();
auto d_weight1 = at::empty({hidden_features, in_features}, opts);
auto d_weight2 = at::empty({out_features, hidden_features}, opts);
auto d_bias1 = at::empty({hidden_features}, opts);
auto d_bias2 = at::empty({out_features}, opts);
CHECK_SHAPE(d_input, batch_size, in_features);
auto d_output1 = at::empty({batch_size, hidden_features}, opts);
//auto reserved_space = at::empty({reserved_size}, inputs[0].type());
// allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
auto lt_workspace = at::empty({1 << 22}, opts);
DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_backward", [&] {
//scalar_t* w_ptr = weight.data_ptr<scalar_t>();
//scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
auto result = linear_gelu_linear_backward_cuda<scalar_t>(
DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
weight.data_ptr<scalar_t>(),
d_output.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
gelu_in.data_ptr<scalar_t>(),
output1.data_ptr<scalar_t>(),
weight1.data_ptr<scalar_t>(),
weight2.data_ptr<scalar_t>(),
d_output1.data_ptr<scalar_t>(),
d_output2.data_ptr<scalar_t>(),
in_features,
batch_size,
hidden_features,
out_features,
heuristic,
d_weight1.data_ptr<scalar_t>(),
d_weight2.data_ptr<scalar_t>(),
d_bias1.data_ptr<scalar_t>(),
d_bias2.data_ptr<scalar_t>(),
d_input.data_ptr<scalar_t>(),
// reserved_space.data_ptr<scalar_t>(),
d_bias.data_ptr<scalar_t>(),
/*residual=*/true,
(void*) (lt_workspace.data_ptr<scalar_t>()));
TORCH_CHECK(result == 0, "linear_residual_gelu_linear_backward failed.")
TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
});
return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
return {d_input, d_bias};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
m.def("linear_bias_residual_backward", &linear_bias_residual_backward, "linear bias residual backward");
m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
m.def("linear_residual_gelu_linear_backward", &linear_residual_gelu_linear_backward, "linear residual gelu linear backward");
}
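The simplified extension now exposes only three ops: a weight-gradient GEMM, a fused linear + bias + GELU forward, and a fused dgrad/bgrad for the bias + GELU + linear pattern; the plain input gradient (d_input = d_output @ weight) is presumably left to an ordinary GEMM on the Python side. Below is a rough, illustrative sketch of how the three bindings could fit together for a two-layer MLP. It is not part of this commit: the extension module name (fused_dense_lib) and the exact wiring are assumptions.

# Illustrative sketch only (not part of this diff). Assumes the extension built from
# csrc/fused_dense_lib is importable as `fused_dense_lib`.
import torch
import fused_dense_lib as fused_dense_cuda

bsz, d_in, d_hidden = 8 * 512, 1024, 4096
dtype = torch.float16
x  = torch.randn(bsz, d_in, device='cuda', dtype=dtype)        # (batch, in_features), contiguous
w1 = torch.randn(d_hidden, d_in, device='cuda', dtype=dtype)   # fc1 weight
b1 = torch.randn(d_hidden, device='cuda', dtype=dtype)         # fc1 bias
w2 = torch.randn(d_in, d_hidden, device='cuda', dtype=dtype)   # fc2 weight
b2 = torch.randn(d_in, device='cuda', dtype=dtype)             # fc2 bias

# Forward: fused GEMM + bias + GELU; save_gelu_in=True also returns the pre-GELU values.
out1, gelu_in = fused_dense_cuda.linear_gelu_forward(x, w1, b1, True, 0)   # heuristic=0
out2 = torch.nn.functional.linear(out1, w2, b2)                 # second GEMM kept as a plain matmul here

# Backward: fuse "dgrad through fc2 + GELU backward + fc1 bias grad" into one call ...
d_out2 = torch.randn_like(out2)                                 # would come from the rest of the network
d_pre_gelu, d_b1 = fused_dense_cuda.bias_gelu_linear_dgrad_bgrad(w2, d_out2, gelu_in, 0)

# ... and use the wgrad op for each linear's weight (and optionally bias) gradient.
d_w2, d_b2 = fused_dense_cuda.linear_bias_wgrad(out1, d_out2, True)        # has_d_bias=True
d_w1, _    = fused_dense_cuda.linear_bias_wgrad(x, d_pre_gelu, False)      # fc1 bias grad already computed
d_x = d_pre_gelu @ w1                                           # plain dgrad: d_input = d_output @ weight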
@@ -10,9 +10,9 @@ from torch.nn.modules.utils import _pair
from einops import rearrange
try:
from flash_attn.ops.fused_dense import FusedDenseTD
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDenseTD = None
FusedDense = None
class PatchEmbed(nn.Module):
@@ -37,10 +37,10 @@ class PatchEmbed(nn.Module):
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
...
@@ -30,9 +30,9 @@ from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
try:
from flash_attn.ops.fused_dense import FusedDenseTD
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDenseTD = None
FusedDense = None
try:
from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
@@ -70,6 +70,8 @@ def create_mlp_cls(config, layer_idx=None, return_residual=False):
activation=partial(F.gelu, approximate=approximate),
return_residual=return_residual)
else:
if FusedDenseGeluDense is None:
raise ImportError('fused_dense is not installed')
mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if isinstance(mlp_checkpoint_lvl, Sequence):
@@ -168,9 +170,9 @@ class BertPooler(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
@@ -188,12 +190,12 @@ class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
if self.fused_dropout_add_ln and layer_norm is None:
raise ImportError('dropout_add_layer_norm is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.dense = linear_cls(config.hidden_size, config.hidden_size)
approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
self.transform_act_fn = nn.GELU(approximate=approximate)
@@ -215,9 +217,9 @@ class BertLMPredictionHead(nn.Module):
def __init__(self, config):
super().__init__()
fused_bias_fc = getattr(config, 'fused_bias_fc', False)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
self.transform = BertPredictionHeadTransform(config)
...
@@ -61,6 +61,8 @@ def create_mlp_cls(config, layer_idx=None):
assert layer_idx is not None
mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
if fused_dense_gelu_dense:
if FusedDenseGeluDense is None:
raise ImportError('fused_dense is not installed')
mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
checkpoint_lvl=mlp_checkpoint_lvl)
elif fused_dense_sqrelu_dense:
...
@@ -21,9 +21,9 @@ except ImportError:
flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
try:
from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseResidual
from flash_attn.ops.fused_dense import FusedDense
except ImportError:
FusedDenseTD, FusedDenseResidual = None, None
FusedDense = None
try:
from flash_attn.layers.rotary import RotaryEmbedding
@@ -270,7 +270,7 @@ class CrossAttention(nn.Module):
class LinearResidual(nn.Linear):
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDenseResidual.
"""Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -311,10 +311,11 @@ class MHA(nn.Module):
assert RotaryEmbedding is not None, 'rotary_emb is not installed'
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
if fused_bias_fc and FusedDenseTD is None:
if fused_bias_fc and FusedDense is None:
raise ImportError('fused_dense is not installed')
linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
linear_cls = nn.Linear if not fused_bias_fc else FusedDense
linear_resid_cls = LinearResidual if not fused_bias_fc else FusedDenseResidual
linear_resid_cls = (LinearResidual if not fused_bias_fc
else partial(FusedDense, return_residual=True))
if not self.cross_attn:
if not self.return_residual:
self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
...
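For context, the FusedDenseResidual class is gone and the "return the input along with the output" behavior is now a FusedDense flag (return_residual=True), with LinearResidual as the unfused fallback. The forward body of LinearResidual is outside this hunk; a minimal sketch of the idea, under the assumption that it simply returns the input alongside the linear output:

# Hedged sketch of the LinearResidual fallback; the actual forward body is not shown in this diff.
import torch
import torch.nn as nn

class LinearResidual(nn.Linear):
    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""
    def forward(self, input: torch.Tensor):
        # Return (output, input) so the caller can fuse the residual add into the backward.
        return super().forward(input), input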
@@ -5,11 +5,9 @@ import torch.nn as nn
import torch.nn.functional as F
try:
from flash_attn.ops.fused_dense import fused_dense_gelu_dense_function_td
from flash_attn.ops.fused_dense import FusedDenseGeluDense
from flash_attn.ops.fused_dense import fused_dense_res_gelu_dense_function_td
except ImportError:
fused_dense_gelu_dense_function_td = None
FusedDenseGeluDense = None
fused_dense_res_gelu_dense_function_td = None
class Mlp(nn.Module):
@@ -30,43 +28,3 @@ class Mlp(nn.Module):
y = self.activation(y)
y = self.fc2(y)
return y if not self.return_residual else (y, x)
class FusedDenseGeluDense(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, bias=True,
checkpoint_lvl=0, heuristic=0, return_residual=False, device=None, dtype=None):
"""
checkpoint_lvl (increasing lvl means slower but more memory saving):
0: no recomputation in the bwd
1: recompute gelu_out in the bwd
2: recompute gelu_in and gelu_out in the bwd
heuristic:
-1: don't fuse gemm + gelu (separate kernel)
0..4: use this heuristic for the algo section in the fused gemm + gelu
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
return_residual: whether to return the input x along with the output. This is for
performance reason: for post-norm architecture, returning the input allows us
to fuse the backward of nn.Linear with the residual connection.
"""
assert checkpoint_lvl in [0, 1, 2]
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert bias == True, "DenseGeluDense module without bias is currently not supported"
assert (fused_dense_gelu_dense_function_td is not None
and fused_dense_res_gelu_dense_function_td is not None), 'fused_dense_lib is not installed'
self.checkpoint_lvl = checkpoint_lvl
self.heuristic = heuristic
self.return_residual = return_residual
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, **factory_kwargs)
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
def forward(self, x):
assert x.is_cuda
fn = (fused_dense_gelu_dense_function_td if not self.return_residual
else fused_dense_res_gelu_dense_function_td)
return fn(x, self.fc1.weight, self.fc1.bias, self.fc2.weight, self.fc2.bias,
self.checkpoint_lvl, self.heuristic)
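The class above is deleted from this file; per the updated test further down, the replacement FusedDenseGeluDense lives in flash_attn.ops.fused_dense and additionally accepts bias1, bias2 and return_residual. A hedged usage sketch following the docstring's checkpoint_lvl and heuristic guidance (constructor arguments mirror the updated test; the CUDA-version check via torch.version.cuda is an assumption, not part of this diff):

# Hedged sketch; assumes torch.version.cuda reflects the toolkit the extension was built with.
import torch
from flash_attn.ops.fused_dense import FusedDenseGeluDense

dtype = torch.float16
cuda_version = tuple(int(v) for v in torch.version.cuda.split('.')[:2])
if cuda_version >= (11, 8):
    heuristic = 0                                      # best for fp16 and bf16 on CUDA >= 11.8
else:
    heuristic = 1 if dtype == torch.float16 else -1    # -1: keep GEMM and GELU unfused for bf16

mlp = FusedDenseGeluDense(1024, 4096, 1024, bias1=True, bias2=True,
                          checkpoint_lvl=1,            # recompute gelu_out in the backward
                          heuristic=heuristic, return_residual=False,
                          device='cuda', dtype=dtype)
x = torch.randn(8, 512, 1024, device='cuda', dtype=dtype)
y = mlp(x)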
@@ -6,29 +6,44 @@ import pytest
from einops import rearrange
from flash_attn.ops.fused_dense import FusedDenseTD, FusedDenseGeluDenseTD
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
from flash_attn.ops.fused_dense import FusedDenseResidual, FusedDenseResGeluDense
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('return_residual', [False, True])
@pytest.mark.parametrize('has_bias', [True, False])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_linear_bias(in_features, out_features, dtype):
def test_fused_linear_bias(in_features, out_features, has_bias, return_residual, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
model = FusedDenseTD(in_features, out_features, device=device, dtype=dtype)
model = FusedDense(in_features, out_features, bias=has_bias, return_residual=return_residual,
device=device, dtype=dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
if has_bias:
model.bias.copy_(model_pt.bias)
out_pt = model_pt(x_pt)
out = model(x)
if not return_residual:
out = model(x)
else:
out, x_copy = model(x)
x_copy = (x_copy[..., :out_features] if out_features < in_features
else F.pad(x_copy, (0, out_features - in_features)))
x_pt_copy = (x_pt[..., :out_features] if out_features < in_features
else F.pad(x_pt, (0, out_features - in_features)))
# Just add some random function of the residual
out_pt = out_pt + F.gelu(x_pt_copy)
out = out + F.gelu(x_copy)
# with torch.no_grad():
# out_fl = F.linear(x_pt.float(), model.weight.float(), model.bias.float()).half()
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
@@ -40,66 +55,52 @@ def test_fused_linear_bias(in_features, out_features, dtype):
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
if has_bias:
assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('out_features,in_features', [(1024, 1024), (4096, 4096)])
def test_fused_linear_bias_residual(in_features, out_features, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model = FusedDenseResidual(in_features, out_features, device=device, dtype=dtype)
with torch.no_grad():
model.weight.copy_(model_pt.weight)
model.bias.copy_(model_pt.bias)
out_pt = model_pt(x_pt) + F.gelu(x_pt) # Just add some random function of the residual x_pt
out, x_copy = model(x)
out = out + F.gelu(x_copy)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.weight.grad, model_pt.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.bias.grad, model_pt.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('heuristic', [1, -1])
@pytest.mark.parametrize('heuristic', [0, -1])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('return_residual', [False, True])
@pytest.mark.parametrize('has_bias2', [True, False])
@pytest.mark.parametrize('has_bias1', [True, False])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuristic, dtype):
def test_fused_dense_gelu_dense(in_features, out_features, has_bias1, has_bias2, return_residual,
checkpoint_lvl, heuristic, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype,
requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
model = FusedDenseGeluDenseTD(in_features, out_features, in_features,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype)
model_pt_fc1 = torch.nn.Linear(in_features, out_features, bias=has_bias1, device=device,
dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
dtype=dtype)
model = FusedDenseGeluDense(in_features, out_features, in_features, bias1=has_bias1,
bias2=has_bias2, return_residual=return_residual,
checkpoint_lvl=checkpoint_lvl, heuristic=heuristic,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
model.fc1.bias.copy_(model_pt_fc1.bias)
if has_bias1:
model.fc1.bias.copy_(model_pt_fc1.bias)
model.fc2.weight.copy_(model_pt_fc2.weight)
model.fc2.bias.copy_(model_pt_fc2.bias)
if has_bias2:
model.fc2.bias.copy_(model_pt_fc2.bias)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
out = model(x)
if not return_residual:
out = model(x)
else:
out, x_copy = model(x)
# Just add some random function of the residual
out_pt = out_pt + F.gelu(x_pt)
out = out + F.gelu(x_copy)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol)
# If we don't divide by batch_size, the gradient gets a bit too large.
@@ -109,46 +110,8 @@ def test_fused_dense_gelu_dense(in_features, out_features, checkpoint_lvl, heuri
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
if has_bias1:
assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('checkpoint_lvl', [0, 1, 2])
@pytest.mark.parametrize('out_features', [1024, 4096])
@pytest.mark.parametrize('in_features', [1024, 4096])
def test_fused_dense_residual_gelu_dense(in_features, out_features, checkpoint_lvl, dtype):
device = 'cuda'
rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 512
x_pt = torch.randn(batch_size, seqlen, in_features, device=device, dtype=dtype, requires_grad=True)
x = x_pt.detach().clone().requires_grad_()
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
model_pt_fc2 = torch.nn.Linear(out_features, in_features, device=device, dtype=dtype)
model = FusedDenseResGeluDense(in_features, out_features, in_features,
checkpoint_lvl=checkpoint_lvl,
device=device, dtype=dtype)
with torch.no_grad():
model.fc1.weight.copy_(model_pt_fc1.weight)
model.fc1.bias.copy_(model_pt_fc1.bias)
model.fc2.weight.copy_(model_pt_fc2.weight)
model.fc2.bias.copy_(model_pt_fc2.bias)
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh')) + F.gelu(x_pt)
out, x_copy = model(x)
out = out + F.gelu(x_copy)
assert torch.allclose(out, out_pt, rtol=rtol, atol=atol * 2)
# If we don't divide by batch_size, the gradient gets a bit too large.
g = torch.randn_like(out) / 32
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# The error for d_weight and d_bias is quite a bit higher
assert torch.allclose(model.fc1.weight.grad, model_pt_fc1.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.fc1.bias.grad, model_pt_fc1.bias.grad, rtol=rtol, atol=atol * 5)
assert torch.allclose(model.fc2.weight.grad, model_pt_fc2.weight.grad, rtol=rtol, atol=atol * 10)
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
if has_bias2:
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)