Unverified Commit 7e9bd08f authored by Terry's avatar Terry Committed by GitHub
Browse files

Add batched RoPE kernel (#3095)

parent ae0ccb40
from typing import Optional
import argparse
import torch
import nvtx
from itertools import accumulate
from vllm.model_executor.layers.rotary_embedding import get_rope
def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Profile batched vs. non-batched RoPE for a multi-LoRA style workload.

    Every token is randomly assigned one of four linear scaling factors.
    The "non-batched" path groups tokens by scaling factor and calls one
    single-factor rope per group; the "batched" path makes a single call
    on a rope built with all factors, passing a per-token offset into its
    concatenated cos/sin cache. Each path is wrapped in an NVTX range
    ("non-batched" / "batched"); nothing is printed or returned, so the
    timings are meant to be read from an external profiler.

    Args:
        is_neox_style: Forwarded to ``get_rope``.
        batch_size: Number of sequences.
        seq_len: Tokens per sequence.
        num_heads: Number of heads; the hidden size of query/key is
            ``num_heads * head_size``.
        head_size: Size of each head.
        rotary_dim: Rotary dimension; ``None`` means ``head_size``.
        dtype: dtype of query/key and of the rope caches.
        seed: Seed for the torch (and CUDA) RNG.
        device: Device used as the torch default device.
        max_position: Unscaled maximum position; positions are sampled
            from ``[0, max_position)``.
        base: RoPE base, forwarded to ``get_rope``.
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
    # Cast the cache to the benchmark dtype (as the unit tests do) so it
    # matches the dtype of query/key handed to the kernel.
    batched_rope = batched_rope.to(dtype=dtype)
    # non-batched RoPE takes only one scaling factor, we create multiple
    # instances to simulate the same behavior
    non_batched_ropes = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }).to(dtype=dtype))

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # Create per-token offsets for the batched RoPE. The batched rope
    # concatenates one cos/sin table per scaling factor along dim 0, and the
    # table for a factor has `max_position * factor` rows, so the table of
    # type i starts at the running sum of the preceding table lengths.
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()

    # Group tokens of the same type together for the non-batched RoPE. The
    # positions are grouped with the same mask so that every selected
    # query/key row is paired with its own position. This is done before the
    # timed region so the gathering is not part of the measurement.
    type_masks = [query_types == i for i in range(len(scaling_factors))]
    packed_pqkr = [(positions[mask], query[mask], key[mask], rope)
                   for mask, rope in zip(type_masks, non_batched_ropes)]

    # synchronize before start timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for p, q, k, r in packed_pqkr:
            r.forward(p, q, k)
        torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
        torch.cuda.synchronize()
if __name__ == '__main__':

    def _str_to_bool(value: str) -> bool:
        """Parse a command-line boolean.

        ``type=bool`` would call ``bool()`` on the raw string, which is True
        for every non-empty value (including "False"), so the flag could
        never be switched off from the command line.
        """
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels.")
    parser.add_argument("--is-neox-style", type=_str_to_bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        # --dtype is restricted to names that are attributes of torch.
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
...@@ -53,6 +53,16 @@ void rotary_embedding( ...@@ -53,6 +53,16 @@ void rotary_embedding(
torch::Tensor& cos_sin_cache, torch::Tensor& cos_sin_cache,
bool is_neox); bool is_neox);
// Batched rotary embedding: same arguments as rotary_embedding plus the
// rotary dimension and a per-token offset tensor. The kernel adds each
// token's offset to its position before indexing cos_sin_cache, which lets
// tokens in one call use different regions of the cache. Updates query and
// key in place.
void batched_rotary_embedding(
torch::Tensor& positions,              // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query,                  // [..., num_heads * head_size]
torch::Tensor& key,                    // [..., num_kv_heads * head_size]
int head_size,
torch::Tensor& cos_sin_cache,          // rows of rot_dim values
bool is_neox,
int rot_dim,
torch::Tensor& cos_sin_cache_offsets); // [num_tokens]
void silu_and_mul( void silu_and_mul(
torch::Tensor& out, torch::Tensor& out,
torch::Tensor& input); torch::Tensor& input);
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace vllm { namespace vllm {
template<typename scalar_t, bool IS_NEOX> template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding( inline __device__ void apply_token_rotary_embedding(
scalar_t* __restrict__ arr, scalar_t* __restrict__ arr,
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ cos_ptr,
const scalar_t* __restrict__ sin_ptr, const scalar_t* __restrict__ sin_ptr,
...@@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding( ...@@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding(
} }
template<typename scalar_t, bool IS_NEOX> template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel( inline __device__ void apply_rotary_embedding(
const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens]
scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const scalar_t* cache_ptr,
const int rot_dim, const int head_size,
const int64_t query_stride,
const int64_t key_stride,
const int num_heads, const int num_heads,
const int num_kv_heads, const int num_kv_heads,
const int head_size) { const int rot_dim,
// Each thread block is responsible for one token. const int token_idx,
const int token_idx = blockIdx.x; const int64_t query_stride,
int64_t pos = positions[token_idx]; const int64_t key_stride)
const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; {
const int embed_dim = rot_dim / 2; const int embed_dim = rot_dim / 2;
const scalar_t* cos_ptr = cache_ptr; const scalar_t* cos_ptr = cache_ptr;
const scalar_t* sin_ptr = cache_ptr + embed_dim; const scalar_t* sin_ptr = cache_ptr + embed_dim;
...@@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel( ...@@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int64_t token_head = token_idx * query_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim); sin_ptr, rot_offset, embed_dim);
} }
...@@ -72,11 +68,53 @@ __global__ void rotary_embedding_kernel( ...@@ -72,11 +68,53 @@ __global__ void rotary_embedding_kernel(
const int head_idx = i / embed_dim; const int head_idx = i / embed_dim;
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int64_t token_head = token_idx * key_stride + head_idx * head_size;
const int rot_offset = i % embed_dim; const int rot_offset = i % embed_dim;
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
sin_ptr, rot_offset, embed_dim); sin_ptr, rot_offset, embed_dim);
} }
} }
// Applies rotary embedding to the query and key of one token per thread
// block; the cos/sin values are taken from the cache row selected by the
// token's position.
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,        // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                 // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // The block index identifies the token handled by this thread block.
  const int token_idx = blockIdx.x;

  // Each cache row holds rot_dim values, so the row for this token starts
  // position * rot_dim elements into the cache.
  const int64_t position = positions[token_idx];
  const scalar_t* token_cache = &cos_sin_cache[position * rot_dim];

  apply_rotary_embedding<scalar_t, IS_NEOX>(
    query, key, token_cache,
    head_size, num_heads, num_kv_heads, rot_dim,
    token_idx, query_stride, key_stride);
}
// Same as rotary_embedding_kernel, except that each token additionally
// carries an offset that is added to its position before the cos/sin cache
// is indexed.
template<typename scalar_t, bool IS_NEOX>
__global__ void batched_rotary_embedding_kernel(
  const int64_t* __restrict__ positions,              // [batch_size, seq_len] or [num_tokens]
  scalar_t* __restrict__ query,                       // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                         // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,         // [max_position, 2, rot_dim // 2]
  const int64_t* __restrict__ cos_sin_cache_offsets,  // [batch_size, seq_len] or [num_tokens]
  const int rot_dim,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
  // The block index identifies the token handled by this thread block.
  const int token_idx = blockIdx.x;

  // Cache row for this token = its per-token offset + its position; each
  // row holds rot_dim values.
  const int64_t cache_row =
    cos_sin_cache_offsets[token_idx] + positions[token_idx];
  const scalar_t* token_cache = &cos_sin_cache[cache_row * rot_dim];

  apply_rotary_embedding<scalar_t, IS_NEOX>(
    query, key, token_cache,
    head_size, num_heads, num_kv_heads, rot_dim,
    token_idx, query_stride, key_stride);
}
} // namespace vllm } // namespace vllm
void rotary_embedding( void rotary_embedding(
...@@ -128,3 +166,61 @@ void rotary_embedding( ...@@ -128,3 +166,61 @@ void rotary_embedding(
} }
}); });
} }
/*
Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.

Launches one thread block per entry of cos_sin_cache_offsets. For token i the
kernel reads the cos/sin values from row (cos_sin_cache_offsets[i] +
positions[i]) of cos_sin_cache and rotates query and key in place.

Raises (via TORCH_CHECK) if positions, query or key do not describe the same
number of tokens as cos_sin_cache_offsets: the kernel indexes all of them by
the same token index, so a mismatch would otherwise read out of bounds or
silently process only part of the input.
*/
void batched_rotary_embedding(
  torch::Tensor& positions,         // [batch_size, seq_len] or [num_tokens]
  torch::Tensor& query,             // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,     // rows of rot_dim values, indexed by offset + position
  bool is_neox,
  int rot_dim,
  torch::Tensor& cos_sin_cache_offsets // [num_tokens]
) {
  // The grid size is taken from the offsets tensor, so every other per-token
  // input must agree with it.
  int64_t num_tokens = cos_sin_cache_offsets.size(0);
  TORCH_CHECK(
    cos_sin_cache_offsets.numel() == num_tokens &&
    positions.numel() == num_tokens,
    "batched_rotary_embedding: cos_sin_cache_offsets must be flattened to "
    "[num_tokens] and match positions; got ", cos_sin_cache_offsets.numel(),
    " offsets (size(0) = ", num_tokens, ") and ", positions.numel(),
    " positions");
  TORCH_CHECK(
    query.numel() == num_tokens * query.size(-1),
    "batched_rotary_embedding: query does not contain ", num_tokens,
    " tokens");
  TORCH_CHECK(
    key.numel() == num_tokens * key.size(-1),
    "batched_rotary_embedding: key does not contain ", num_tokens,
    " tokens");

  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  // Element stride between consecutive tokens.
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  // One thread per (head, rotated pair) of the query, capped at 512 threads.
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),
    "batched_rotary_embedding",
    [&] {
      if (is_neox) {
        vllm::batched_rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        vllm::batched_rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          cos_sin_cache_offsets.data_ptr<int64_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
    });
}
...@@ -56,6 +56,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -56,6 +56,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&rotary_embedding, &rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
ops.def(
"batched_rotary_embedding",
&batched_rotary_embedding,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)");
// Quantization ops // Quantization ops
#ifndef USE_ROCM #ifndef USE_ROCM
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
......
from typing import Optional from typing import List, Optional
import pytest import pytest
import torch import torch
from allclose_default import get_default_atol, get_default_rtol from allclose_default import get_default_atol, get_default_rtol
from itertools import accumulate
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
IS_NEOX_STYLE = [True, False] IS_NEOX_STYLE = [True, False]
...@@ -72,3 +73,135 @@ def test_rotary_embedding( ...@@ -72,3 +73,135 @@ def test_rotary_embedding(
ref_key, ref_key,
atol=get_default_atol(out_key), atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key)) rtol=get_default_rtol(out_key))
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Check the batched kernel against the PyTorch reference.

    Uses a single scaling factor of 1 and all-zero offsets, and compares
    the output of ``forward`` (called with ``offsets``) with the output of
    the reference ``_forward`` (called without offsets).
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    rotary_dim = head_size if rotary_dim is None else rotary_dim
    rope_scaling = {"type": "linear", "factor": (1, )}
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                    rope_scaling).to(dtype=dtype)

    hidden_size = num_heads * head_size
    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype)
    key = torch.randn_like(query)

    # One zero offset per token, already flattened.
    zero_offsets = torch.zeros(batch_size * seq_len, dtype=int, device=device)

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope._forward(positions, query, key)
    out_query, out_key = rope.forward(positions,
                                      query,
                                      key,
                                      offsets=zero_offsets)

    # Compare the results.
    for actual, expected in ((out_query, ref_query), (out_key, ref_key)):
        assert torch.allclose(actual,
                              expected,
                              atol=get_default_atol(actual),
                              rtol=get_default_rtol(actual))
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_batched_rotary_embedding_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    """Check the batched kernel against the reference with mixed offsets.

    Builds one rope with several linear scaling factors, assigns every
    token a random scaling-factor type, and compares ``forward`` (flattened
    per-token offsets) with the reference ``_forward`` (offsets shaped like
    ``positions``).
    """
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    scaling_factors: List[int] = [1, 2, 4]
    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
        "type": "linear",
        "factor": tuple(scaling_factors)
    })
    rope = rope.to(dtype=dtype)

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # The rope concatenates one cos/sin table per scaling factor along dim 0
    # and the table for a factor has `max_position * factor` rows, so the
    # table of type i starts at the running sum of the preceding lengths.
    # (Multiplying by an extra 2 would place the offsets in the middle of
    # the following tables instead of at their start.)
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    query_offsets = offset_map[query_types]

    # NOTE(woosuk): The reference implementation should be executed first
    # because the custom kernel is in-place.
    ref_query, ref_key = rope._forward(positions, query, key, query_offsets)
    # The kernel takes the offsets flattened to one entry per token.
    out_query, out_key = rope.forward(positions, query, key,
                                      query_offsets.flatten())
    # Compare the results.
    assert torch.allclose(out_query,
                          ref_query,
                          atol=get_default_atol(out_query),
                          rtol=get_default_rtol(out_query))
    assert torch.allclose(out_key,
                          ref_key,
                          atol=get_default_atol(out_key),
                          rtol=get_default_rtol(out_key))
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Rotary Positional Embeddings.""" """Rotary Positional Embeddings."""
import math import math
from typing import Any, Dict, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -96,6 +96,7 @@ class RotaryEmbedding(nn.Module): ...@@ -96,6 +96,7 @@ class RotaryEmbedding(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
query = query.view(*query.shape[:-1], -1, self.head_size) query = query.view(*query.shape[:-1], -1, self.head_size)
...@@ -107,7 +108,9 @@ class RotaryEmbedding(nn.Module): ...@@ -107,7 +108,9 @@ class RotaryEmbedding(nn.Module):
query_pass = query[..., self.rotary_dim:] query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:]
cos_sin = self.cos_sin_cache[positions] self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style: if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the # NOTE(woosuk): Here we assume that the positions tensor has the
...@@ -137,11 +140,19 @@ class RotaryEmbedding(nn.Module): ...@@ -137,11 +140,19 @@ class RotaryEmbedding(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# ops.rotary_embedding() is an in-place operation that self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device())
# updates the query and key tensors. # ops.rotary_embedding()/batched_rotary_embedding() are in-place operations that
ops.rotary_embedding(positions, query, key, self.head_size, # update the query and key tensors.
self.cos_sin_cache, self.is_neox_style) if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
...@@ -158,27 +169,32 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -158,27 +169,32 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
max_position_embeddings: int, max_position_embeddings: int,
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factors: Union[List[float], float],
) -> None: ) -> None:
self.scaling_factor = scaling_factor if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors = scaling_factors
super().__init__(head_size, rotary_dim, max_position_embeddings, base, super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style) is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base) inv_freq = self._compute_inv_freq(self.base)
# NOTE(woosuk): self.max_position_embeddings is the original cache_list = []
# maximum length before applying the rope scaling. for scaling_factor in self.scaling_factors:
# Thus, the maximum length after applying the rope scaling is # NOTE(woosuk): self.max_position_embeddings is the original
# self.max_position_embeddings * self.scaling_factor. # maximum length before applying the rope scaling.
max_len = self.max_position_embeddings * self.scaling_factor # Thus, the maximum length after applying the rope scaling is
t = torch.arange(max_len, dtype=torch.float) # self.max_position_embeddings * self.scaling_factor.
t = t / self.scaling_factor max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) t = t / scaling_factor
cos = freqs.cos()
sin = freqs.sin() freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((cos, sin), dim=-1) cos = freqs.cos()
return cache sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
cache_list.append(cache)
return torch.cat(cache_list, dim=0)
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment