"launch/vscode:/vscode.git/clone" did not exist on "c13ea718999806322e2c88fdd40d06aa45801990"
Commit 3f498d32 authored by Guolin Ke

save memory for softmax

parent 97fd9948
@@ -5,7 +5,7 @@
 std::vector<c10::optional<torch::Tensor>> fwd_cuda(
   bool is_training,
-  const torch::Tensor &input,
+  torch::Tensor &input,
   float dropout_prob,
   c10::optional<at::Generator> gen_
 );
@@ -25,7 +25,7 @@ torch::Tensor bwd_cuda(
 std::vector<c10::optional<torch::Tensor>> fwd(
   bool is_training,
-  const torch::Tensor &input,
+  torch::Tensor &input,
   float dropout_prob,
   c10::optional<at::Generator> gen_
 ) {
...
@@ -18,7 +18,7 @@
 std::vector<c10::optional<torch::Tensor>> fwd_cuda(
   bool is_training,
-  const torch::Tensor &input,
+  torch::Tensor &input,
   float dropout_prob,
   c10::optional<at::Generator> gen_
 ) {
@@ -29,11 +29,10 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
   auto act_options = input.options().requires_grad(false);
   auto mask_options = act_options.dtype(softmax_mask_dtype(k_seq_len));
-  torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
   // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
   void *input_ptr = reinterpret_cast<void *>(input.data_ptr());
-  void *softmax_results_ptr = reinterpret_cast<void *>(softmax_results.data_ptr());
+  void *softmax_results_ptr = reinterpret_cast<void *>(input.data_ptr());
   // Padded Softmax
   bool softmax_success = false;
@@ -84,7 +83,7 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
     softmax_success = false;
   }
   if (softmax_success) {
-    return {dropout_results, dropout_mask, softmax_results};
+    return {dropout_results, dropout_mask, input};
   } else {
     return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
   }
@@ -120,7 +119,7 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
     softmax_success = false;
   }
   if (softmax_success) {
-    return {softmax_results, c10::optional<torch::Tensor>(), softmax_results};
+    return {input, c10::optional<torch::Tensor>(), input};
   } else {
     return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
   }
@@ -131,9 +130,7 @@ torch::Tensor bwd_cuda(
   torch::Tensor &output_grads,
   const torch::Tensor &softmax_results,
   const c10::optional<torch::Tensor> &dropout_mask,
-  float dropout_prob
-)
-{
+  float dropout_prob) {
   const int attn_batches = output_grads.size(0);
   const int q_seq_len = output_grads.size(1);
   const int k_seq_len = output_grads.size(2);
...
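The memory saving comes from dropping the separate softmax_results allocation: the kernel now writes the softmax probabilities straight into the input buffer (softmax_results_ptr points at input.data_ptr()), so only one attention-score-sized tensor stays alive, and that same buffer is what gets returned for the backward pass. A minimal Python-level sketch of the same idea follows; InPlaceSoftmax is a hypothetical stand-in, not this extension's API, and it assumes the caller no longer needs the raw scores afterwards.

import torch

class InPlaceSoftmax(torch.autograd.Function):
    # Hypothetical sketch: overwrite the attention scores with their softmax
    # instead of allocating a second (attn_batches, q_len, k_len) tensor.
    @staticmethod
    def forward(ctx, scores):
        # Numerically stable softmax computed with in-place ops on `scores`.
        scores.sub_(scores.amax(dim=-1, keepdim=True))
        scores.exp_()
        scores.div_(scores.sum(dim=-1, keepdim=True))
        ctx.mark_dirty(scores)           # an input was modified in place
        ctx.save_for_backward(scores)    # the probabilities are all backward needs
        return scores

    @staticmethod
    def backward(ctx, grad_out):
        probs, = ctx.saved_tensors
        # d(softmax): p * (g - sum(g * p)) along the softmax dimension.
        return probs * (grad_out - (grad_out * probs).sum(dim=-1, keepdim=True))

Used as probs = InPlaceSoftmax.apply(attn_scores), the caller must treat attn_scores as clobbered, which mirrors how fwd_cuda now returns input where it previously returned softmax_results.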
@@ -15,12 +15,12 @@ class FusedLayerNormFastFunction(torch.autograd.Function):
     def forward(ctx, input, weight, bias, normalized_shape, eps):
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
-        input_ = input.contiguous()
-        weight_ = weight.contiguous()
-        bias_ = bias.contiguous()
+        input = input.contiguous()
+        weight = weight.contiguous()
+        bias = bias.contiguous()
         output, mean, invvar = unicore_fused_layernorm.forward(
-            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
-        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
+            input, ctx.normalized_shape, weight, bias, ctx.eps)
+        ctx.save_for_backward(input, weight, bias, mean, invvar)
         return output
     @staticmethod
     def backward(ctx, grad_output):
...
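The layer-norm change is a smaller version of the same theme: rebinding input, weight, and bias to their contiguous copies, instead of introducing input_/weight_/bias_ aliases, means nothing in forward keeps the originals referenced, so a non-contiguous original can be freed as soon as its contiguous copy exists. A rough pure-PyTorch sketch is below; LayerNormSketch is hypothetical, F.layer_norm stands in for the unicore_fused_layernorm extension, and the recomputation-based backward is only for illustration.

import torch
import torch.nn.functional as F

class LayerNormSketch(torch.autograd.Function):
    # Hypothetical stand-in for FusedLayerNormFastFunction.
    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        # Rebind rather than alias: once these names point at the contiguous
        # copies, nothing in this frame keeps the original tensors alive.
        input = input.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        output = F.layer_norm(input, normalized_shape, weight, bias, eps)
        ctx.save_for_backward(input, weight, bias)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        # Recompute through autograd for brevity; the real extension uses the
        # saved mean/invvar and a fused backward kernel instead.
        with torch.enable_grad():
            inp = input.detach().requires_grad_(True)
            w = weight.detach().requires_grad_(True)
            b = bias.detach().requires_grad_(True)
            out = F.layer_norm(inp, ctx.normalized_shape, w, b, ctx.eps)
            gi, gw, gb = torch.autograd.grad(out, (inp, w, b), grad_output)
        return gi, gw, gb, None, None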