Commit 97951590 authored by Tri Dao's avatar Tri Dao

[Rotary] Set device before launching Triton kernel to avoid error

parent 6d673cd9
@@ -205,31 +205,34 @@ def apply_rotary(
     grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
     BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
-    rotary_kernel[grid](
-        output,  # data ptrs
-        x,
-        cos,
-        sin,
-        cu_seqlens,
-        seqlen_offsets,
-        seqlen,  # shapes
-        nheads,
-        rotary_dim,
-        seqlen_ro,
-        seqlen // 128,  # key for triton cache (limit number of compilations)
-        output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
-        output.stride(-3),  # seqlen_stride or total_seqlen_stride
-        output.stride(-2),  # nheads_stride
-        output.stride(-1),  # headdim_stride
-        x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
-        x.stride(-3),  # seqlen stride or total_seqlen_stride
-        x.stride(-2),  # nheads stride
-        x.stride(-1),  # headdim stride
-        BLOCK_K,
-        isinstance(seqlen_offsets, torch.Tensor),
-        is_varlen,
-        interleaved,
-        conjugate,
-        BLOCK_M,
-    )
+    # Need this, otherwise Triton tries to launch from cuda:0 and we get
+    # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    with torch.cuda.device(x.device.index):
+        rotary_kernel[grid](
+            output,  # data ptrs
+            x,
+            cos,
+            sin,
+            cu_seqlens,
+            seqlen_offsets,
+            seqlen,  # shapes
+            nheads,
+            rotary_dim,
+            seqlen_ro,
+            seqlen // 128,  # key for triton cache (limit number of compilations)
+            output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+            output.stride(-3),  # seqlen_stride or total_seqlen_stride
+            output.stride(-2),  # nheads_stride
+            output.stride(-1),  # headdim_stride
+            x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+            x.stride(-3),  # seqlen stride or total_seqlen_stride
+            x.stride(-2),  # nheads stride
+            x.stride(-1),  # headdim stride
+            BLOCK_K,
+            isinstance(seqlen_offsets, torch.Tensor),
+            is_varlen,
+            interleaved,
+            conjugate,
+            BLOCK_M,
+        )
     return output
@@ -148,9 +148,6 @@ def test_baichuan_parallel_forward(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()
-    # Need this, otherwise the Triton kernel seems to launched from the wrong device.
-    torch.cuda.set_device(device)
     pretrained_state_dict = remap_state_dict_hf_baichuan(
         state_dict_from_pretrained(model_name), config
     )
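
In short, the fix moves device selection into apply_rotary itself: the kernel launch is wrapped in torch.cuda.device(x.device.index) so Triton launches on the tensor's own device rather than the current default (often cuda:0), and the per-test torch.cuda.set_device workaround can be dropped. Below is a minimal sketch of the same guard pattern on a toy Triton kernel; the kernel and helper names are illustrative, not from this repository.

import torch
import triton
import triton.language as tl


@triton.jit
def _add_one_kernel(x_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized chunk of the flat tensor.
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(x_ptr + offsets, x + 1, mask=mask)


def add_one_(x: torch.Tensor) -> torch.Tensor:
    """In-place x += 1, launched on x's own device (hypothetical example)."""
    assert x.is_cuda
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # Without this guard, Triton launches from the current CUDA device (often cuda:0);
    # if x lives on another GPU, the launch fails with
    # "ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)".
    with torch.cuda.device(x.device.index):
        _add_one_kernel[grid](x, n, BLOCK=1024)
    return x


# e.g. on a multi-GPU machine:
#   t = torch.zeros(10, device="cuda:1")
#   add_one_(t)  # runs on cuda:1 without an explicit torch.cuda.set_device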