gaoqiong / flash-attention · Commits

Commit 767b71cc (Unverified)
Fix random state for dropout_layer_norm (#315)
Authored Jul 23, 2023 by Joel Lamy-Poirier; committed by GitHub, Jul 23, 2023
Parent: d38357dd

1 changed file, with 10 additions and 10 deletions:
csrc/layer_norm/ln_api.cpp (+10, -10)
csrc/layer_norm/ln_api.cpp @ 767b71cc

@@ -229,11 +229,6 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     // Request the kernel launcher.
     auto launcher = get_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
-    // Query the kernel-specific launch parameters.
-    launcher(launch_params, true);
-
-    at::Tensor workspace, barrier;
-
     // Set the kernel runtime parameters.
     layer_norm::FwdParams &params = launch_params.params;
     params.rows = rows;
@@ -252,6 +247,11 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
     params.rowscale_const = rowscale_const;
     params.is_rms_norm = is_rms_norm;
 
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random
         // state
@@ -594,11 +594,6 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
     // Request the kernel launcher.
     auto launcher = get_parallel_fwd_launcher(wtype, itype, rtype, otype, ctype, round_multiple(hidden_size, multiple));
 
-    // Query the kernel-specific launch parameters.
-    launcher(launch_params, true);
-
-    at::Tensor workspace, barrier;
-
     // Set the kernel runtime parameters.
     layer_norm::FwdParams &params = launch_params.params;
     params.rows = rows;
@@ -621,6 +616,11 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
     params.inverse_cols = 1.f / float(params.cols);
     params.is_rms_norm = is_rms_norm;
 
+    // Query the kernel-specific launch parameters.
+    launcher(launch_params, true);
+
+    at::Tensor workspace, barrier;
+
     if (dropout_p > 0.f) {
         // number of times random will be generated per thread, to offset philox counter in thc random
         // state
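The shape of the bug can be sketched with a minimal ordering example. All names here are hypothetical stand-ins, not the real flash-attention `ln_api` types: the point is only that if the launcher derives a kernel-specific quantity (here, a per-thread element count) from the runtime parameters, and the philox counter offset for the dropout RNG is computed from that quantity, then querying the launcher before the parameters are filled in produces a stale offset.

```cpp
#include <cassert>

// Hypothetical stand-ins for the launch/runtime parameter structs.
struct FwdParams { int rows = 0; float dropout_p = 0.f; };
struct LaunchParams { FwdParams params; int elts_per_thread = 0; };

// When its second argument is true, fill in kernel-specific launch
// parameters derived from the runtime parameters.
void launcher(LaunchParams &lp, bool configure_params) {
    if (configure_params) {
        lp.elts_per_thread = lp.params.rows * 4;  // depends on params.rows
    }
}

// Philox counter offset for the dropout branch: scales with the
// per-thread element count the launcher computed.
int philox_offset(const LaunchParams &lp) {
    return lp.params.dropout_p > 0.f ? lp.elts_per_thread : 0;
}

int offset_query_before_set() {   // the pre-fix ordering
    LaunchParams lp;
    launcher(lp, true);           // queried before params are filled in
    lp.params.rows = 8;
    lp.params.dropout_p = 0.1f;
    return philox_offset(lp);     // stale: 0
}

int offset_set_before_query() {   // the post-fix ordering
    LaunchParams lp;
    lp.params.rows = 8;
    lp.params.dropout_p = 0.1f;
    launcher(lp, true);           // queried after params are filled in
    return philox_offset(lp);     // 8 * 4 = 32
}
```

This is why the diff moves the `launcher(launch_params, true)` query below the block that sets `params`, rather than changing the launcher itself.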