OpenDAS / Uni-Core
Commit 3f498d32, authored Jul 13, 2022 by Guolin Ke
Parent: 97fd9948

save memory for softmax
Showing 3 changed files with 12 additions and 15 deletions:

- csrc/softmax_dropout/interface.cpp (+2, -2)
- csrc/softmax_dropout/softmax_dropout_kernel.cu (+5, -8)
- unicore/modules/layer_norm.py (+5, -5)
csrc/softmax_dropout/interface.cpp
@@ -5,7 +5,7 @@
 std::vector<c10::optional<torch::Tensor>> fwd_cuda(
     bool is_training,
-    const torch::Tensor &input,
+    torch::Tensor &input,
     float dropout_prob,
     c10::optional<at::Generator> gen_
 );
@@ -25,7 +25,7 @@ torch::Tensor bwd_cuda(
 std::vector<c10::optional<torch::Tensor>> fwd(
     bool is_training,
-    const torch::Tensor &input,
+    torch::Tensor &input,
     float dropout_prob,
     c10::optional<at::Generator> gen_
 ) {
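Both hunks in this file make the same one-line change: `input` loses its `const` qualifier in `fwd_cuda` and `fwd`. That is the API-level signal for this commit: the forward pass is now allowed to overwrite the input buffer with the softmax result instead of filling a freshly allocated tensor. A minimal Python-side sketch of the new contract (`fused_softmax_fwd` is a hypothetical stand-in, not the actual Uni-Core binding):

```python
import torch

def fused_softmax_fwd(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the extension's forward: with the const gone, the
    # kernel may write its result straight into x's storage.
    x.copy_(torch.softmax(x, dim=-1))  # the CUDA kernel avoids even this temp
    return x

scores = torch.randn(2, 16, 16)
backup = scores.clone()              # clone first if the raw scores are reused
out = fused_softmax_fwd(scores)
assert out.data_ptr() == scores.data_ptr()  # same storage, no second buffer
```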
csrc/softmax_dropout/softmax_dropout_kernel.cu
@@ -18,7 +18,7 @@
 std::vector<c10::optional<torch::Tensor>> fwd_cuda(
     bool is_training,
-    const torch::Tensor &input,
+    torch::Tensor &input,
     float dropout_prob,
     c10::optional<at::Generator> gen_
 ) {
@@ -29,11 +29,10 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
     auto act_options = input.options().requires_grad(false);
     auto mask_options = act_options.dtype(softmax_mask_dtype(k_seq_len));
-    torch::Tensor softmax_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
     // Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
     void* input_ptr = reinterpret_cast<void*>(input.data_ptr());
-    void* softmax_results_ptr = reinterpret_cast<void*>(softmax_results.data_ptr());
+    void* softmax_results_ptr = reinterpret_cast<void*>(input.data_ptr());
     // Padded Softmax
     bool softmax_success = false;
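This hunk is the memory saving itself. The separate `softmax_results` buffer, a full attention-sized allocation of `attn_batches × q_seq_len × k_seq_len`, is no longer created; instead `softmax_results_ptr` now aliases `input`'s storage, so the kernel writes the softmax output over the logits it just read. A sketch of the same idea in plain PyTorch, assuming nothing beyond stock tensor ops; the in-place ops keep peak usage at one attention-sized buffer:

```python
import torch

def softmax_inplace(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Numerically stable softmax written back into x's own storage."""
    x.sub_(x.amax(dim=dim, keepdim=True))  # subtract the row max for stability
    x.exp_()                               # exponentiate in place
    x.div_(x.sum(dim=dim, keepdim=True))   # normalize rows in place
    return x

x = torch.randn(4, 128, 128)        # (attn_batches, q_seq_len, k_seq_len)
y = softmax_inplace(x)
assert y.data_ptr() == x.data_ptr() # result reuses the input allocation
assert torch.allclose(y.sum(-1), torch.ones(4, 128), atol=1e-5)
```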
@@ -84,7 +83,7 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
         softmax_success = false;
     }
     if (softmax_success) {
-        return {dropout_results, dropout_mask, softmax_results};
+        return {dropout_results, dropout_mask, input};
     } else {
         return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
     }
@@ -120,7 +119,7 @@ std::vector<c10::optional<torch::Tensor>> fwd_cuda(
         softmax_success = false;
     }
     if (softmax_success) {
-        return {softmax_results, c10::optional<torch::Tensor>(), softmax_results};
+        return {input, c10::optional<torch::Tensor>(), input};
     } else {
         return {c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>(), c10::optional<torch::Tensor>()};
     }
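These two return sites follow from the aliasing above: every slot that used to hand back the separately allocated `softmax_results` now hands back `input`, which holds the same values after the in-place write. In the second path, the first and third slots are literally the same tensor. A small illustration of that aliasing from the caller's point of view (names are illustrative):

```python
import torch

# Mirror of `return {input, nullopt, input}`: the output slot and the
# saved-softmax slot are one and the same tensor object.
softmaxed = torch.softmax(torch.randn(2, 8, 8), dim=-1)
output, dropout_mask, saved = softmaxed, None, softmaxed
assert output is saved  # one allocation serves both return slots
```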
@@ -131,9 +130,7 @@ torch::Tensor bwd_cuda(
     torch::Tensor &output_grads,
     const torch::Tensor &softmax_results,
     const c10::optional<torch::Tensor> &dropout_mask,
     float dropout_prob
 ) {
     const int attn_batches = output_grads.size(0);
     const int q_seq_len = output_grads.size(1);
     const int k_seq_len = output_grads.size(2);
 ...
unicore/modules/layer_norm.py
@@ -15,12 +15,12 @@ class FusedLayerNormFastFunction(torch.autograd.Function):
     def forward(ctx, input, weight, bias, normalized_shape, eps):
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
-        input_ = input.contiguous()
-        weight_ = weight.contiguous()
-        bias_ = bias.contiguous()
+        input = input.contiguous()
+        weight = weight.contiguous()
+        bias = bias.contiguous()
         output, mean, invvar = unicore_fused_layernorm.forward(
-            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
-        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
+            input, ctx.normalized_shape, weight, bias, ctx.eps)
+        ctx.save_for_backward(input, weight, bias, mean, invvar)
         return output

     @staticmethod
     def backward(ctx, grad_output):
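The layer-norm change rebinds `input`, `weight`, and `bias` to their `.contiguous()` versions instead of introducing the `input_`/`weight_`/`bias_` names alongside them. Since `.contiguous()` returns the tensor itself when the layout is already contiguous, the common case is unchanged; when a copy is made, the rebinding drops this frame's reference to the original, so the non-contiguous source does not have to stay alive next to its contiguous copy for the rest of `forward`. A quick sketch of both behaviors:

```python
import torch

a = torch.randn(4, 4)
assert a.contiguous() is a   # already contiguous: no copy, same tensor back

b = torch.randn(4, 4).t()    # transposed view is non-contiguous
assert not b.is_contiguous()
b = b.contiguous()           # copy made; rebinding releases the old reference
assert b.is_contiguous()
```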