Commit 3f498d32 authored by Guolin Ke's avatar Guolin Ke
Browse files

save memory for softmax

parent 97fd9948
......@@ -5,7 +5,7 @@
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
const torch::Tensor &input,
torch::Tensor &input,
float dropout_prob,
c10::optional<at::Generator> gen_
);
......@@ -25,7 +25,7 @@ torch::Tensor bwd_cuda(
std::vector<c10::optional<torch::Tensor>> fwd(
bool is_training,
const torch::Tensor &input,
torch::Tensor &input,
float dropout_prob,
c10::optional<at::Generator> gen_
) {
......
......@@ -18,7 +18,7 @@
std::vector<c10::optional<torch::Tensor>> fwd_cuda(
bool is_training,
const torch::Tensor &input,
torch::Tensor &input,
float dropout_prob,
c10::optional<at::Generator> gen_
) {
......@@ -29,11 +29,10 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
auto act_options = input.options().requires_grad(false);
auto mask_options = act_options.dtype(softmax_mask_dtype(k_seq_len));
torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void *input_ptr = reinterpret_cast<void *>(input.data_ptr());
void *softmax_results_ptr = reinterpret_cast<void *>(softmax_results.data_ptr());
void *softmax_results_ptr = reinterpret_cast<void *>(input.data_ptr());
// Padded Softmax
bool softmax_success = false;
......@@ -84,7 +83,7 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
softmax_success = false;
}
if (softmax_success) {
return {dropout_results, dropout_mask, softmax_results};
return {dropout_results, dropout_mask, input};
} else {
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
......@@ -120,7 +119,7 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
softmax_success = false;
}
if (softmax_success) {
return {softmax_results, c10::optional<torch::Tensor>(), softmax_results};
return {input, c10::optional<torch::Tensor>(), input};
} else {
return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
}
......@@ -131,9 +130,7 @@ torch::Tensor bwd_cuda(
torch::Tensor &output_grads,
const torch::Tensor &softmax_results,
const c10::optional<torch::Tensor> &dropout_mask,
float dropout_prob
)
{
float dropout_prob) {
const int attn_batches = output_grads.size(0);
const int q_seq_len = output_grads.size(1);
const int k_seq_len = output_grads.size(2);
......
......@@ -15,12 +15,12 @@ class FusedLayerNormFastFunction(torch.autograd.Function):
def forward(ctx, input, weight, bias, normalized_shape, eps):
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
input = input.contiguous()
weight = weight.contiguous()
bias = bias.contiguous()
output, mean, invvar = unicore_fused_layernorm.forward(
input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
input, ctx.normalized_shape, weight, bias, ctx.eps)
ctx.save_for_backward(input, weight, bias, mean, invvar)
return output
@staticmethod
def backward(ctx, grad_output):
......
Markdown is supported
Attach a file by drag & drop or click to upload. 0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment