Unverified commit cf0ccd40, authored by Lianmin Zheng, committed by GitHub
Browse files

Optimize rope in sgl kernel (#4267)

parent 3d56585a
...@@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache( ...@@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache(
static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(q_rope.data_ptr()),
static_cast<c_type*>(k_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
static_cast<float*>(cos_sin_cache.data_ptr()), static_cast<float*>(cos_sin_cache.data_ptr()),
static_cast<int32_t*>(pos_ids.data_ptr()), static_cast<int64_t*>(pos_ids.data_ptr()),
nnz, nnz,
num_qo_heads, num_qo_heads,
num_kv_heads, num_kv_heads,
......
...@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace( ...@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace(
if cos_sin_cache.dtype != torch.float32: if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32") raise ValueError("cos_sin_cache should be float32")
positions = positions.int()
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache( torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
q=query.view(query.shape[0], -1, head_size), q=query.view(query.shape[0], -1, head_size),
k=key.view(key.shape[0], -1, head_size), k=key.view(key.shape[0], -1, head_size),
q_rope=query.view(query.shape[0], -1, head_size), q_rope=query.view(query.shape[0], -1, head_size),
k_rope=key.view(key.shape[0], -1, head_size), k_rope=key.view(key.shape[0], -1, head_size),
cos_sin_cache=cos_sin_cache, cos_sin_cache=cos_sin_cache,
pos_ids=positions, pos_ids=positions.long(),
interleave=(not is_neox), interleave=(not is_neox),
cuda_stream=get_cuda_stream(), cuda_stream=get_cuda_stream(),
) )
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment