Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
50601bf4
"vscode:/vscode.git/clone" did not exist on "7df79c86ddc4ebf36de94671b454485caf6cc395"
Commit
50601bf4
authored
May 19, 2024
by
Woosuk Kwon
Browse files
Use int64_t for page pointer arth
parent
f80aa0fd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
9 deletions
+9
-9
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+9
-9
No files found.
csrc/flash_attn/src/utils.h
View file @
50601bf4
...
@@ -296,21 +296,21 @@ void cp_async_wait() {
...
@@ -296,21 +296,21 @@ void cp_async_wait() {
// assumes that the tensor has already been positioned at the correct head.
// assumes that the tensor has already been positioned at the correct head.
template
<
typename
Kernel_traits
>
template
<
typename
Kernel_traits
>
__forceinline__
__device__
__forceinline__
__device__
int
resolve_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block_max
,
const
int
page_block_size
,
int
64_t
resolve_thread_kv_page_slice_offset
(
const
int
tidx
,
const
int
n_block_max
,
const
int
page_block_size
,
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
const
int
*
block_table
,
const
int
page_stride
,
const
int
row_stride
)
{
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemThreadsPerRow
=
Kernel_traits
::
kGmemThreadsPerRow
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
constexpr
int
kGmemRowsPerThread
=
Kernel_traits
::
kGmemRowsPerThread
;
constexpr
int
kGmemElemsPerLoad
=
Kernel_traits
::
kGmemElemsPerLoad
;
constexpr
int
kGmemElemsPerLoad
=
Kernel_traits
::
kGmemElemsPerLoad
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
const
int
col_offset
=
tidx
%
kGmemThreadsPerRow
*
kGmemElemsPerLoad
;
const
int
64_t
col_offset
=
tidx
%
kGmemThreadsPerRow
*
kGmemElemsPerLoad
;
const
int
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
const
int
64_t
block_row_offset
=
tidx
/
kGmemThreadsPerRow
*
kGmemRowsPerThread
;
const
int
global_row_offset
=
block_row_offset
+
(
n_block_max
-
1
)
*
kBlockN
;
const
int
64_t
global_row_offset
=
block_row_offset
+
(
n_block_max
-
1
)
*
kBlockN
;
const
int
page_offset
=
global_row_offset
%
page_block_size
;
const
int
64_t
page_offset
=
global_row_offset
%
page_block_size
;
const
int
virtual_page_idx
=
global_row_offset
/
page_block_size
;
const
int
65_t
virtual_page_idx
=
global_row_offset
/
page_block_size
;
return
block_table
[
virtual_page_idx
]
*
page_stride
return
((
int64_t
)
block_table
[
virtual_page_idx
]
)
*
((
int64_t
)
page_stride
)
+
page_offset
*
row_stride
+
page_offset
*
((
int64_t
)
row_stride
)
+
col_offset
;
+
col_offset
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment