gaoqiong / flash-attention

Commit e9820408, authored Nov 28, 2024 by Woosuk Kwon

test

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Parent: fdf6d72b
Showing 1 changed file with 5 additions and 5 deletions
csrc/flash_attn/flash_api.cpp

@@ -665,11 +665,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     // number of times random will be generated per thread, to offset philox counter in thc random
     // state
     // We use a custom RNG that increases the offset by batch_size * nheads * 32.
-    // int64_t counter_offset = params.b * params.h * 32;
-    // auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    // auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
-    // // Forward kernel will populate memory with the seed and offset.
-    // params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+    int64_t counter_offset = params.b * params.h * 32;
+    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    // Forward kernel will populate memory with the seed and offset.
+    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

     // if (p_dropout > 0.0)  {
     //     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
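For context, the restored lines advance the philox counter by batch_size * nheads * 32 per call and hand the forward kernel a two-slot int64 CUDA tensor into which it writes the RNG seed and offset. The following is a minimal host-side sketch of that pattern, not the actual flash-attention code: DummyParams and launch_fwd_kernel are hypothetical stand-ins for the forward-pass params struct and kernel launch used in flash_api.cpp, and the example assumes libtorch with a CUDA device available.

// Hypothetical sketch of the rng_state bookkeeping shown in the diff above.
#include <torch/torch.h>
#include <cstdint>

struct DummyParams {       // hypothetical stand-in for the real params struct
    int64_t b;             // batch size
    int64_t h;             // number of heads
    uint64_t* rng_state;   // kernel writes {philox seed, philox offset} here
};

int main() {
    DummyParams params{/*b=*/4, /*h=*/8, /*rng_state=*/nullptr};

    // Same bookkeeping as in the diff: the philox counter is advanced by
    // batch_size * nheads * 32 per forward call.
    int64_t counter_offset = params.b * params.h * 32;

    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
    params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

    // A real forward kernel would fill rng_state with the seed/offset pair:
    // launch_fwd_kernel(params, counter_offset);   // hypothetical launch

    // Read the pair back on the host; a backward pass can re-seed philox
    // with these values to replay the same dropout mask.
    auto host_state = rng_state.to(torch::kCPU);
    int64_t seed = host_state[0].item<int64_t>();
    int64_t offset = host_state[1].item<int64_t>();
    (void)seed; (void)offset; (void)counter_offset;
    return 0;
}

The read-back at the end mirrors how the seed/offset pair is typically consumed later, e.g. by a backward pass that must reproduce the forward pass's dropout pattern.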