Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d93c976a
Unverified
Commit
d93c976a
authored
May 14, 2025
by
Lucas Wilkinson
Committed by
GitHub
May 14, 2025
Browse files
[Kernel] Have rotary embeddings support tensors (#18046)
Signed-off-by:
Lucas Wilkinson
<
lwilkinson@neuralmagic.com
>
parent
749f7925
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
31 deletions
+59
-31
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+28
-12
tests/kernels/core/test_pos_encoding.py
tests/kernels/core/test_pos_encoding.py
+13
-1
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+15
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+3
-16
No files found.
csrc/pos_encoding_kernels.cu
View file @
d93c976a
...
@@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding(
...
@@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding(
// head_size]
// head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int64_t
head_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
@@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding(
...
@@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding(
const
int
nq
=
num_heads
*
embed_dim
;
const
int
nq
=
num_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nq
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_stride
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
...
@@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
...
@@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding(
const
int
nk
=
num_kv_heads
*
embed_dim
;
const
int
nk
=
num_kv_heads
*
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
head_idx
=
i
/
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_stride
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
...
@@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel(
...
@@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel(
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
...
@@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel(
...
@@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
...
@@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
...
@@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel(
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
// 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
// or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
head_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
...
@@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
...
@@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
token_idx
,
query_stride
,
key_stride
,
head_stride
);
}
}
}
// namespace vllm
}
// namespace vllm
...
@@ -179,6 +183,12 @@ void rotary_embedding(
...
@@ -179,6 +183,12 @@ void rotary_embedding(
int
seq_dim_idx
=
positions_ndim
-
1
;
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
@@ -190,14 +200,14 @@ void rotary_embedding(
...
@@ -190,14 +200,14 @@ void rotary_embedding(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
}
});
});
}
}
...
@@ -263,6 +273,12 @@ void batched_rotary_embedding(
...
@@ -263,6 +273,12 @@ void batched_rotary_embedding(
int
seq_dim_idx
=
positions_ndim
-
1
;
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
// Determine head stride: for [*, heads, head_size] use stride of last dim;
// for flat [*, heads*head_size], heads blocks are contiguous of size
// head_size
int
query_ndim
=
query
.
dim
();
int64_t
head_stride
=
(
query_ndim
==
positions_ndim
+
2
)
?
query
.
stride
(
-
2
)
:
head_size
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
@@ -276,7 +292,7 @@ void batched_rotary_embedding(
...
@@ -276,7 +292,7 @@ void batched_rotary_embedding(
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
...
@@ -284,7 +300,7 @@ void batched_rotary_embedding(
...
@@ -284,7 +300,7 @@ void batched_rotary_embedding(
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
head_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
}
});
});
}
}
tests/kernels/core/test_pos_encoding.py
View file @
d93c976a
...
@@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
...
@@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
return
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
return
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
# For testing sliced tensors
def
_get_padded_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
)
->
tuple
[
int
,
...]:
return
(
batch_size
,
seq_len
,
num_heads
,
head_size
+
64
)
def
_get_batch_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
def
_get_batch_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
)
->
tuple
[
int
,
...]:
head_size
:
int
)
->
tuple
[
int
,
...]:
return
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
return
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
TENSORS_SHAPES_FN
=
[
_get_batch_tensor_shape
,
_get_flat_tensor_shape
]
TENSORS_SHAPES_FN
=
[
_get_batch_tensor_shape
,
_get_flat_tensor_shape
,
_get_padded_tensor_shape
]
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -79,6 +87,10 @@ def test_rotary_embedding(
...
@@ -79,6 +87,10 @@ def test_rotary_embedding(
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
# slice tensor if required, noop otherwise
query
=
query
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
d93c976a
...
@@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
...
@@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot,
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"head_stride_is_contingous"
,
[
True
,
False
])
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
,
use_key
):
seq_len
,
use_key
,
head_stride_is_contingous
):
batch_size
=
1
batch_size
=
1
base
=
10000
base
=
10000
num_heads
=
7
num_heads
=
7
...
@@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
positions
=
torch
.
randint
(
0
,
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
),
max_position
,
(
batch_size
,
seq_len
),
device
=
device
)
device
=
device
)
head_stride
=
head_size
+
(
64
if
head_stride_is_contingous
else
0
)
query
=
torch
.
randn
(
batch_size
,
query
=
torch
.
randn
(
batch_size
,
seq_len
,
seq_len
,
num_heads
*
head_size
,
num_heads
,
head_stride
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
device
)
device
=
device
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
query
=
query
[...,
:
head_size
]
key
=
key
[...,
:
head_size
]
if
use_key
else
None
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
device
=
device
,
device
=
device
,
dtype
=
torch
.
long
)
dtype
=
torch
.
long
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
,
offsets
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
,
offsets
)
# if we have a contiguous head stride, test the alternate
# [..., num_heads * head_dim] shape/layout
if
head_stride_is_contingous
:
rotary_embedding_opcheck
(
rot
,
positions
,
query
.
flatten
(
start_dim
=-
2
),
key
.
flatten
(
start_dim
=-
2
)
if
use_key
else
None
)
vllm/_custom_ops.py
View file @
d93c976a
...
@@ -254,14 +254,8 @@ def rotary_embedding(
...
@@ -254,14 +254,8 @@ def rotary_embedding(
cos_sin_cache
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
is_neox
:
bool
,
)
->
None
:
)
->
None
:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
torch
.
ops
.
_C
.
rotary_embedding
(
positions
,
query
,
key
,
head_size
,
query_contiguous
=
query
.
contiguous
()
cos_sin_cache
,
is_neox
)
key_contiguous
=
key
.
contiguous
()
if
key
is
not
None
else
None
torch
.
ops
.
_C
.
rotary_embedding
(
positions
,
query_contiguous
,
key_contiguous
,
head_size
,
cos_sin_cache
,
is_neox
)
query
.
copy_
(
query_contiguous
)
if
key
is
not
None
:
key
.
copy_
(
key_contiguous
)
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
...
@@ -269,16 +263,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
...
@@ -269,16 +263,9 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
rot_dim
:
int
,
rot_dim
:
int
,
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
torch
.
ops
.
_C
.
batched_rotary_embedding
(
positions
,
query
,
key
,
head_size
,
query_contiguous
=
query
.
contiguous
()
key_contiguous
=
key
.
contiguous
()
if
key
is
not
None
else
None
torch
.
ops
.
_C
.
batched_rotary_embedding
(
positions
,
query_contiguous
,
key_contiguous
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache_offsets
)
cos_sin_cache_offsets
)
query
.
copy_
(
query_contiguous
)
if
key
is
not
None
:
key
.
copy_
(
key_contiguous
)
# layer norm ops
# layer norm ops
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment