gaoqiong / flash-attention

Commit 018b1b5f, authored Nov 28, 2024 by Woosuk Kwon

test

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Parent: e9820408
Pipeline #2018 failed in 0 seconds
Showing 1 changed file with 10 additions and 10 deletions

csrc/flash_attn/flash_api.cpp (+10 / -10)
csrc/flash_attn/flash_api.cpp (view file @ 018b1b5f)

@@ -406,16 +406,16 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                        params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
                        head_size_rounded, p_dropout, /*num_splits*/ 0, dprops, opts);
-    // number of times random will be generated per thread, to offset philox counter in thc random
-    // state
-    // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    int64_t counter_offset = params.b * params.h * 32;
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // Forward kernel will populate memory with the seed and offset.
-    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
-    // NOTE(woosuk): Commented out because they are not used in inference.
+    // NOTE(woosuk): Commented out because they are not used in inference.
+    // // number of times random will be generated per thread, to offset philox counter in thc random
+    // // state
+    // // We use a custom RNG that increases the offset by batch_size * nheads * 32.
+    // int64_t counter_offset = params.b * params.h * 32;
+    // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // // Forward kernel will populate memory with the seed and offset.
+    // params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
     // if (p_dropout > 0.0) {
     //     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
     //         gen_, at::cuda::detail::getDefaultCUDAGenerator());
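For context, the block disabled above is the dropout RNG bookkeeping: the Philox counter offset is advanced by batch_size * nheads * 32 per forward call (the "number of times random will be generated per thread"), and a two-element int64 CUDA tensor is allocated so the forward kernel can record the seed and offset it actually used. A minimal standalone sketch of that bookkeeping, with ForwardParams and make_rng_state as illustrative names standing in for the real Flash_fwd_params plumbing in flash_api.cpp:

    #include <cstdint>
    #include <torch/torch.h>

    // Illustrative stand-in for the handful of Flash_fwd_params fields used here.
    struct ForwardParams {
        int64_t b;            // batch size
        int64_t h;            // number of attention heads
        uint64_t *rng_state;  // the kernel writes {seed, philox_offset} into this buffer
    };

    // Mirrors the commented-out setup in mha_fwd: compute the Philox counter
    // increment and hand the kernel a buffer in which to record seed/offset.
    torch::Tensor make_rng_state(ForwardParams &params) {
        // The custom RNG advances the Philox offset by batch_size * nheads * 32
        // per call so successive launches never reuse the same random stream.
        int64_t counter_offset = params.b * params.h * 32;
        (void)counter_offset;  // would feed philox_cuda_state() when dropout is active

        auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA);
        torch::Tensor rng_state = torch::empty({2}, options);
        params.rng_state = reinterpret_cast<uint64_t *>(rng_state.data_ptr());
        return rng_state;  // keep it alive while the kernel may still write to it
    }

Since dropout is disabled at inference time, none of this state is ever consumed, which is what the NOTE(woosuk) comment refers to.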
@@ -661,7 +661,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
                        p_dropout, /*num_splits*/ 0, dprops, opts);
     }
-    // NOTE(woosuk): Commented out because they are not used in inference.
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -671,6 +670,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     // Forward kernel will populate memory with the seed and offset.
     params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    // NOTE(woosuk): Commented out because they are not used in inference.
     // if (p_dropout > 0.0) {
     //     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
     //         gen_, at::cuda::detail::getDefaultCUDAGenerator());
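The relocated NOTE now sits directly above the dropout-only branch that stays commented out. In upstream flash-attention this guard is where the Philox state is actually seeded; the sketch below reconstructs that pattern for reference. Only the first three commented lines appear in this diff, so the lock_guard/philox_cuda_state continuation is an assumption based on the usual PyTorch CUDA generator idiom, and seed_dropout_rng is an illustrative name, not a function in flash_api.cpp:

    #include <mutex>
    #include <ATen/cuda/CUDAContext.h>
    #include <ATen/cuda/CUDAGeneratorImpl.h>  // header path varies across PyTorch versions
    #include <c10/util/Optional.h>

    // Hedged sketch of the dropout-only RNG seeding that the diff keeps disabled.
    at::PhiloxCudaState seed_dropout_rng(float p_dropout,
                                         c10::optional<at::Generator> gen_,
                                         int64_t counter_offset) {
        at::PhiloxCudaState philox_args;  // default state: no randomness consumed
        if (p_dropout > 0.0f) {
            auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
                gen_, at::cuda::detail::getDefaultCUDAGenerator());
            // Acquire the generator lock before advancing its Philox counter.
            std::lock_guard<std::mutex> lock(gen->mutex_);
            philox_args = gen->philox_cuda_state(counter_offset);
        }
        return philox_args;  // untouched when p_dropout == 0, i.e. at inference time
    }

With p_dropout fixed at zero for inference, this branch never executes, which is why both mha_fwd and mha_varlen_fwd can leave it commented out.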