gaoqiong / flash-attention
Commit 97951590, authored Sep 05, 2023 by Tri Dao
[Rotary] Set device before launching Triton kernel to avoid error
parent 6d673cd9
Showing 2 changed files with 30 additions and 30 deletions

flash_attn/ops/triton/rotary.py  +30 -27
tests/models/test_baichuan.py  +0 -3
flash_attn/ops/triton/rotary.py
@@ -205,6 +205,9 @@ def apply_rotary(
     grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
     BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
 
+    # Need this, otherwise Triton tries to launch from cuda:0 and we get
+    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    with torch.cuda.device(x.device.index):
-    rotary_kernel[grid](
-        output,  # data ptrs
-        x,
+        rotary_kernel[grid](
+            output,  # data ptrs
+            x,
...
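The pattern in this hunk, making the input tensor's device the current CUDA device for the duration of the launch, can be sketched as a small helper. This is only an illustration: `launch_device_guard` is a hypothetical name that does not appear in the commit, which calls `torch.cuda.device(x.device.index)` inline. The sketch assumes the argument exposes `.type` and `.index` the way `torch.device` does, and it imports `torch` lazily so the non-CUDA path runs on machines without PyTorch.

```python
import contextlib
from types import SimpleNamespace


def launch_device_guard(device):
    """Return a context manager that makes `device` the current CUDA device.

    Hypothetical helper, not part of flash-attention. `device` is assumed to
    look like a torch.device (attributes `.type` and `.index`).
    """
    if getattr(device, "type", None) != "cuda":
        # Nothing to switch for non-CUDA devices.
        return contextlib.nullcontext()
    import torch  # lazy import: only needed on the CUDA path

    # torch.cuda.device switches the current device on entry and restores
    # the previous one on exit, so the caller's device is left untouched.
    return torch.cuda.device(device.index)


# Stand-in for a CPU torch.device, so this demo runs without a GPU.
cpu_guard = launch_device_guard(SimpleNamespace(type="cpu", index=None))
with cpu_guard:
    pass  # a kernel launch would go here
```

With a real CUDA tensor `x`, the kernel launch would sit inside `with launch_device_guard(x.device):`. Scoping the switch to the launch, rather than calling `torch.cuda.set_device` globally, restores the previous current device when the block exits.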
tests/models/test_baichuan.py
@@ -148,9 +148,6 @@ def test_baichuan_parallel_forward(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()
 
-    # Need this, otherwise the Triton kernel seems to launched from the wrong device.
-    torch.cuda.set_device(device)
-
     pretrained_state_dict = remap_state_dict_hf_baichuan(
         state_dict_from_pretrained(model_name), config
     )
...