Unverified Commit 94bff099 authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

RoPE enhancements (#1478)



* add support for `sb1d` freqs tensor in Fused RoPE
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add `start_positions` variable to `apply_rotary_pos_emb` function to make staggered rope application faster
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add pytorch path for `start_positions` and corresponding tests
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add tests for start_positions with thd
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes from feedback
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove start_positions from backward pass
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* from feedback
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make notes shorter
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 9a819334
......@@ -22,6 +22,7 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return torch.sum(output * t)
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
......@@ -43,7 +44,17 @@ def test_fused_rope(
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
if margin == 0 and start_positions == True:
# This makes sure that the `start_positions` offsets being applied
# are within the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
......@@ -51,6 +62,14 @@ def test_fused_rope(
dtype=dtype,
device=device,
)
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
if start_positions
else None
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
if transpose:
......@@ -69,14 +88,18 @@ def test_fused_rope(
t.float(),
emb,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
......@@ -84,21 +107,29 @@ def test_fused_rope(
t,
emb,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("margin", [10])
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
......@@ -114,10 +145,25 @@ def test_fused_rope_thd(
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
margin: int,
) -> None:
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
if start_positions
else None
)
if cp_size > 1:
cu_seqlens_padded = [0]
for i in range(1, len(cu_seqlens)):
......@@ -152,6 +198,7 @@ def test_fused_rope_thd(
output_unfused = apply_rotary_pos_emb(
t.float(),
emb,
start_positions=start_positions,
tensor_format="thd",
interleaved=interleaved,
fused=False,
......@@ -160,14 +207,17 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t,
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
tensor_format="thd",
......@@ -176,9 +226,15 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
......@@ -115,10 +115,10 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
template <typename scalar_t>
__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const bool interleaved,
const int cp_size, const int cp_rank, const int s,
const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b,
const float *freqs, const int *start_positions,
scalar_t *dst, const bool interleaved, const int cp_size,
const int cp_rank, const int s, const int h, const int d,
const int d2, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
......@@ -149,7 +149,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
s_id_for_freqs = s_id + begin_offset;
}
fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
......@@ -199,11 +200,12 @@ __global__ void fused_rope_backward_kernel(
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
scalar_t *output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
const int *start_positions, scalar_t *output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
......@@ -223,8 +225,9 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
const int o_stride_d = 1;
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d);
input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -262,15 +265,17 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
}
void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b, const int h,
const int d, const int d2, const int stride_s_or_t, const int stride_b,
const Tensor &start_positions, Tensor *output,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
const int cp_rank, const int s, const int b, const int h, const int d,
const int d2, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<const int *>(start_positions.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream););
......@@ -295,19 +300,19 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);
fused_rope_forward(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
......
......@@ -20,6 +20,7 @@ extern "C" {
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] output Output tensor.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
......@@ -37,12 +38,12 @@ extern "C" {
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope.
*
......
......@@ -278,6 +278,7 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
**************************************************************************************************/
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const std::optional<at::Tensor> start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
......
......@@ -7,6 +7,7 @@
#include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const std::optional<at::Tensor> start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
......@@ -26,6 +27,11 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
auto start_positions_cu = transformer_engine::TensorWrapper();  // empty start_positions tensor
if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value());
}
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
......@@ -54,9 +60,9 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
start_positions_cu.data(), output_cu.data(), qkv_format, interleaved,
cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0,
stride_h, stride_d, at::cuda::getCurrentCUDAStream());
return output;
}
......@@ -87,9 +93,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
"greater than the freqs tensor");
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
start_positions_cu.data(), output_cu.data(), qkv_format, interleaved,
cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
return output;
}
......@@ -142,8 +149,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
max_s, b, h, d, d2, stride_t,
/*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
......
......@@ -119,6 +119,7 @@ class FusedRoPEFunc(torch.autograd.Function):
ctx,
t: torch.Tensor,
freqs: torch.Tensor,
start_positions: Union[torch.Tensor, None] = None,
tensor_format: str = "sbhd",
interleaved: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
......@@ -126,6 +127,7 @@ class FusedRoPEFunc(torch.autograd.Function):
cp_rank: int = 0,
) -> torch.Tensor:
"""Fused RoPE forward."""
if freqs.dtype != torch.float32:
freqs = freqs.float()
assert tensor_format in (
......@@ -134,7 +136,14 @@ class FusedRoPEFunc(torch.autograd.Function):
"thd",
), f"Unsupported tensor_format: {tensor_format}."
output = tex.fused_rope_forward(
t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank
t,
freqs,
start_positions,
QKVFormat[tensor_format],
interleaved,
cu_seqlens,
cp_size,
cp_rank,
)
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
......@@ -158,7 +167,7 @@ class FusedRoPEFunc(torch.autograd.Function):
ctx.cp_rank,
)
return grad_input, None, None, None, None, None, None
return grad_input, None, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
......@@ -185,6 +194,7 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
def _apply_rotary_pos_emb_base(
t: torch.Tensor,
freqs: torch.Tensor,
start_positions: torch.Tensor = None,
tensor_format: str = "sbhd",
interleaved: bool = False,
) -> torch.Tensor:
......@@ -199,6 +209,9 @@ def _apply_rotary_pos_emb_base(
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
start_positions: torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
`[seq, bs, ...]`.
......@@ -208,14 +221,32 @@ def _apply_rotary_pos_emb_base(
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# In case `start_positions` are provided, create a staggered `freqs` tensor
# offset by the values in `start_positions`.
# `start_positions` is only supported for `cp_size=1` and inference.
if start_positions is not None:
max_offset = torch.max(start_positions)
assert (
max_offset + cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
# Stack staggered rope embeddings along the batch dimension
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
# Note that from this point, `freqs` has a shape `(s,b,1,d)`.
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
# [seq, 1, 1, dim] -> [1, seq, 1, dim] or
# [seq, b, 1, dim] -> [b, seq, 1, dim]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
freqs = freqs.transpose(0, 1)
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
......@@ -252,13 +283,14 @@ def _get_freqs_on_this_cp_rank(
)
# cp_size == 1
return freqs[:seqlen]
return freqs
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
start_positions: Union[torch.Tensor, None] = None,
interleaved: bool = False,
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
......@@ -268,6 +300,19 @@ def apply_rotary_pos_emb(
"""
Apply rotary positional embedding tensor to the input tensor.
Support matrix:
Fused/Unfused:
Training:
qkv_formats: "thd", "bshd", "sbhd"
context parallel: yes
start_positions: no
interleaving: yes
Inference:
qkv_formats: "thd", "bshd", "sbhd"
context parallelism: no
start_positions: yes
interleaving: yes
Parameters
----------
t: torch.Tensor
......@@ -276,6 +321,9 @@ def apply_rotary_pos_emb(
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
start_positions: torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
......@@ -292,27 +340,43 @@ def apply_rotary_pos_emb(
cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
# `start_positions` is only supported for `cp_size=1` and inference.
assert not (
cp_size > 1 and start_positions is not None
), """start_positions != None with CP SIZE > 1 is not supported!"""
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
if fused:
return FusedRoPEFunc.apply(
t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
)
# Unfused THD format
if tensor_format == "thd":
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
# The following code essentially splits the `thd` tensor into corresponding
# `s1hd` tensors (for each sequence) and applies rotary embedding to
# those sequences individually.
# Note that if `start_positions` is not `None`, then for each sequence,
# its corresponding RoPE offset is also supplied from `start_positions`
# individually.
return torch.cat(
[
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
_get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
start_positions=(
start_positions[idx : idx + 1] if start_positions is not None else None
),
interleaved=interleaved,
)
for x in torch.split(t, seqlens)
for idx, x in enumerate(torch.split(t, seqlens))
]
).squeeze(1)
......@@ -326,6 +390,7 @@ def apply_rotary_pos_emb(
return _apply_rotary_pos_emb_base(
t,
_get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
start_positions,
tensor_format,
interleaved=interleaved,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment