gaoqiong / flash-attention · Commit 97951590

[Rotary] Set device before launching Triton kernel to avoid error

Authored Sep 05, 2023 by Tri Dao
Parent: 6d673cd9
Showing 2 changed files with 30 additions and 30 deletions:

flash_attn/ops/triton/rotary.py   (+30 -27)
tests/models/test_baichuan.py     (+0 -3)
flash_attn/ops/triton/rotary.py

@@ -205,31 +205,34 @@ def apply_rotary(
     grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
     BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
-    rotary_kernel[grid](
-        output,  # data ptrs
-        x,
-        cos,
-        sin,
-        cu_seqlens,
-        seqlen_offsets,
-        seqlen,  # shapes
-        nheads,
-        rotary_dim,
-        seqlen_ro,
-        seqlen // 128,  # key for triton cache (limit number of compilations)
-        output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
-        output.stride(-3),  # seqlen_stride or total_seqlen_stride
-        output.stride(-2),  # nheads_stride
-        output.stride(-1),  # headdim_stride
-        x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
-        x.stride(-3),  # seqlen stride or total_seqlen_stride
-        x.stride(-2),  # nheads stride
-        x.stride(-1),  # headdim stride
-        BLOCK_K,
-        isinstance(seqlen_offsets, torch.Tensor),
-        is_varlen,
-        interleaved,
-        conjugate,
-        BLOCK_M,
-    )
+    # Need this, otherwise Triton tries to launch from cuda:0 and we get
+    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    with torch.cuda.device(x.device.index):
+        rotary_kernel[grid](
+            output,  # data ptrs
+            x,
+            cos,
+            sin,
+            cu_seqlens,
+            seqlen_offsets,
+            seqlen,  # shapes
+            nheads,
+            rotary_dim,
+            seqlen_ro,
+            seqlen // 128,  # key for triton cache (limit number of compilations)
+            output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+            output.stride(-3),  # seqlen_stride or total_seqlen_stride
+            output.stride(-2),  # nheads_stride
+            output.stride(-1),  # headdim_stride
+            x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+            x.stride(-3),  # seqlen stride or total_seqlen_stride
+            x.stride(-2),  # nheads stride
+            x.stride(-1),  # headdim stride
+            BLOCK_K,
+            isinstance(seqlen_offsets, torch.Tensor),
+            is_varlen,
+            interleaved,
+            conjugate,
+            BLOCK_M,
+        )
     return output
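The change above is a general pattern: when a Triton kernel can receive tensors on a GPU other than cuda:0, the launch should happen inside a torch.cuda.device context for the tensor's own device; otherwise Triton launches from the current (default) device and fails with the ValueError quoted in the added comment. A minimal sketch of that pattern follows, using a toy copy kernel and assuming a machine with at least two GPUs; none of the names below come from this repository.

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(dst_ptr, src_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program copies one BLOCK-sized chunk of the input to the output.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)


def copy(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(x.numel(), META["BLOCK"]),)
    # Same guard as in apply_rotary: launch from the device that owns the tensors,
    # not from whatever device happens to be current (usually cuda:0).
    with torch.cuda.device(x.device.index):
        copy_kernel[grid](out, x, x.numel(), BLOCK=1024)
    return out


if __name__ == "__main__":
    x = torch.randn(4096, device="cuda:1")  # assumes a second GPU is present
    assert torch.equal(copy(x), x)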
tests/models/test_baichuan.py

@@ -148,9 +148,6 @@ def test_baichuan_parallel_forward(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()
-    # Need this, otherwise the Triton kernel seems to launched from the wrong device.
-    torch.cuda.set_device(device)
     pretrained_state_dict = remap_state_dict_hf_baichuan(state_dict_from_pretrained(model_name), config)