Commit 97951590 authored by Tri Dao's avatar Tri Dao
Browse files

[Rotary] Set device before launching Triton kernel to avoid error

parent 6d673cd9
...@@ -205,6 +205,9 @@ def apply_rotary(
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
# Need this, otherwise Triton tries to launch from cuda:0 and we get
# ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
with torch.cuda.device(x.device.index):
        rotary_kernel[grid](
            output,  # data ptrs
            x,
......
...@@ -148,9 +148,6 @@ def test_baichuan_parallel_forward(model_name, world_size):
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()
# Need this, otherwise the Triton kernel seems to launched from the wrong device.
torch.cuda.set_device(device)
    pretrained_state_dict = remap_state_dict_hf_baichuan(
        state_dict_from_pretrained(model_name), config
    )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment