Unverified commit 94bff099, authored by Sudhakar Singh, committed by GitHub

RoPE enhancements (#1478)



* add support for `sb1d` freqs tensor in Fused RoPE
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add `start_positions` variable to `apply_rotary_pos_emb` function to make staggered rope application faster
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add pytorch path for `start_positions` and corresponding tests
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add tests for start_positions with thd
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes from feedback
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove start_positions from backward pass
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* from feedback
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make notes shorter
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 9a819334
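
In short: this PR threads an optional `start_positions` tensor through `apply_rotary_pos_emb`, the unfused PyTorch path, and the fused CUDA kernel, so that sequence `i` is rotated as if its first token sat at position `start_positions[i]`. That turns staggered RoPE application during inference (each sequence resuming from a different cached length) into a single call. A minimal usage sketch follows; the import path, shapes, and the hand-rolled `freqs` table are illustrative assumptions, not part of this diff (in real code the table comes from `RotaryPositionEmbedding`):

```python
import torch
from transformer_engine.pytorch.attention import apply_rotary_pos_emb  # path may vary by TE version

seq_len, batch, heads, dim = 128, 2, 8, 64
t = torch.randn(seq_len, batch, heads, dim, device="cuda")  # "sbhd" layout

# Hand-rolled rotary table of shape [max_seq_len, 1, 1, dim], for illustration only.
max_seq_len = 256
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device="cuda").float() / dim))
angles = torch.outer(torch.arange(max_seq_len, device="cuda").float(), inv_freq)
freqs = torch.cat((angles, angles), dim=-1).view(max_seq_len, 1, 1, dim)

# Sequence 0 starts at position 0; sequence 1 resumes as if 17 tokens were already cached.
start_positions = torch.tensor([0, 17], dtype=torch.int32, device="cuda")

out = apply_rotary_pos_emb(
    t, freqs, tensor_format="sbhd", start_positions=start_positions, fused=True
)
```

With `start_positions=None` the behavior is unchanged, and the argument is rejected when `cp_size > 1` (context parallelism), matching the support matrix added to the docstring below.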
@@ -22,6 +22,7 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
     return torch.sum(output * t)


+@pytest.mark.parametrize("start_positions", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("seq_length", [2048, 4096])
 @pytest.mark.parametrize("hidden_size", [128, 256])
@@ -43,7 +44,17 @@ def test_fused_rope(
     loss_func: Callable,
     cp_size: int,
     interleaved: bool,
+    start_positions: bool,
 ) -> None:
+    if margin == 0 and start_positions == True:
+        # This makes sure that the `start_positions` offsets being applied
+        # are within the maximum length of the RoPE embeddings.
+        pytest.skip("Skipping test with margin=0 and start_positions=True")
+
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
     t = torch.rand(
@@ -51,6 +62,14 @@ def test_fused_rope(
         dtype=dtype,
         device=device,
     )
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
     if tensor_format == "bshd":
         t = t.transpose(0, 1).contiguous()
     if transpose:
@@ -69,14 +88,18 @@
         t.float(),
         emb,
         tensor_format=tensor_format,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=False,
         cp_size=cp_size,
         cp_rank=cp_rank,
     ).to(dtype)
     loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
-    t.grad = None
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+        t.grad = None

     # fused
@@ -84,21 +107,29 @@
         t,
         emb,
         tensor_format=tensor_format,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=True,
         cp_size=cp_size,
         cp_rank=cp_rank,
     )
     loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
-    t.grad = None
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+        t.grad = None

     torch.testing.assert_close(output_fused, output_unfused)
-    torch.testing.assert_close(grad_fused, grad_unfused)
+    if not isinstance(start_positions, torch.Tensor):
+        torch.testing.assert_close(grad_fused, grad_unfused)
     assert output_fused.is_contiguous()


+@pytest.mark.parametrize("margin", [10])
+@pytest.mark.parametrize("start_positions", [True, False])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@@ -114,10 +145,25 @@ def test_fused_rope_thd(
     loss_func: Callable,
     cp_size: int,
     interleaved: bool,
+    start_positions: bool,
+    margin: int,
 ) -> None:
+    if start_positions == True and cp_size > 1:
+        # `start_positions` is only supported for `cp_size=1` and inference.
+        pytest.skip("Skipping test with cp_size>1 and start_positions=True")
+
     device = torch.device("cuda:0")
     batch_size, head_num = 2, 64
     cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
+
+    # Get arbitrary offsets to be used with RoPE for all the sequences
+    start_positions = (
+        torch.randint(0, margin, (len(cu_seqlens) - 1,), dtype=torch.int32, device=device)
+        if start_positions
+        else None
+    )
+
     if cp_size > 1:
         cu_seqlens_padded = [0]
         for i in range(1, len(cu_seqlens)):
@@ -152,6 +198,7 @@ def test_fused_rope_thd(
     output_unfused = apply_rotary_pos_emb(
         t.float(),
         emb,
+        start_positions=start_positions,
         tensor_format="thd",
         interleaved=interleaved,
         fused=False,
@@ -160,6 +207,8 @@ def test_fused_rope_thd(
         cp_rank=cp_rank,
     ).to(dtype)
     loss_unfused = loss_func(output_unfused)
-    loss_unfused.backward()
-    grad_unfused = t.grad.detach().clone()
-    t.grad = None
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_unfused.backward()
+        grad_unfused = t.grad.detach().clone()
+        t.grad = None
@@ -168,6 +217,7 @@ def test_fused_rope_thd(
     output_fused = apply_rotary_pos_emb(
         t,
         emb,
+        start_positions=start_positions,
         interleaved=interleaved,
         fused=True,
         tensor_format="thd",
@@ -176,9 +226,15 @@ def test_fused_rope_thd(
         cp_rank=cp_rank,
     )
     loss_fused = loss_func(output_fused)
-    loss_fused.backward()
-    grad_fused = t.grad.detach().clone()
-    t.grad = None
+
+    if not isinstance(start_positions, torch.Tensor):
+        loss_fused.backward()
+        grad_fused = t.grad.detach().clone()
+        t.grad = None

     torch.testing.assert_close(output_fused, output_unfused)
-    torch.testing.assert_close(grad_fused, grad_unfused)
+    if not isinstance(start_positions, torch.Tensor):
+        torch.testing.assert_close(grad_fused, grad_unfused)
+
+    assert output_fused.is_contiguous()
@@ -115,10 +115,10 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
 template <typename scalar_t>
 __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
-                                          const float *freqs, scalar_t *dst, const bool interleaved,
-                                          const int cp_size, const int cp_rank, const int s,
-                                          const int h, const int d, const int d2,
-                                          const int stride_s_or_t, const int stride_b,
+                                          const float *freqs, const int *start_positions,
+                                          scalar_t *dst, const bool interleaved, const int cp_size,
+                                          const int cp_rank, const int s, const int h, const int d,
+                                          const int d2, const int stride_s_or_t, const int stride_b,
                                           const int stride_h, const int stride_d,
                                           const int o_stride_s_or_t, const int o_stride_b,
                                           const int o_stride_h, const int o_stride_d) {
@@ -149,7 +149,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
           cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
     }
   } else {
-    s_id_for_freqs = s_id;
+    int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
+    s_id_for_freqs = s_id + begin_offset;
   }
   fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
@@ -199,11 +200,12 @@ __global__ void fused_rope_backward_kernel(
 template <typename scalar_t>
 void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
-                                 scalar_t *output, const NVTE_QKV_Format qkv_format,
-                                 const bool interleaved, const int cp_size, const int cp_rank,
-                                 const int s, const int b, const int h, const int d, const int d2,
-                                 const int stride_s_or_t, const int stride_b, const int stride_h,
-                                 const int stride_d, cudaStream_t stream) {
+                                 const int *start_positions, scalar_t *output,
+                                 const NVTE_QKV_Format qkv_format, const bool interleaved,
+                                 const int cp_size, const int cp_rank, const int s, const int b,
+                                 const int h, const int d, const int d2, const int stride_s_or_t,
+                                 const int stride_b, const int stride_h, const int stride_d,
+                                 cudaStream_t stream) {
   int warps_per_block = h < 16 ? 4 : 8;
   dim3 blocks(s, b);
   dim3 threads(THREADS_PER_WARP, warps_per_block);
@@ -223,8 +225,9 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
   const int o_stride_d = 1;

   fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
-      input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t,
-      stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d);
+      input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
+      stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
+      o_stride_d);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
@@ -262,15 +265,17 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
 }

 void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
-                        Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved,
-                        const int cp_size, const int cp_rank, const int s, const int b, const int h,
-                        const int d, const int d2, const int stride_s_or_t, const int stride_b,
+                        const Tensor &start_positions, Tensor *output,
+                        const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
+                        const int cp_rank, const int s, const int b, const int h, const int d,
+                        const int d2, const int stride_s_or_t, const int stride_b,
                         const int stride_h, const int stride_d, cudaStream_t stream) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
       input.data.dtype, scalar_t,
       fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
                                   reinterpret_cast<const int *>(cu_seqlens.data.dptr),
                                   reinterpret_cast<const float *>(freqs.data.dptr),
+                                  reinterpret_cast<const int *>(start_positions.data.dptr),
                                   reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
                                   interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
                                   stride_b, stride_h, stride_d, stream););
@@ -295,19 +300,19 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
 }  // end namespace transformer_engine

 void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                             const NVTETensor freqs, NVTETensor output,
-                             const NVTE_QKV_Format qkv_format, const bool interleaved,
-                             const int cp_size, const int cp_rank, const int s, const int b,
-                             const int h, const int d, const int d2, const int stride_s_or_t,
-                             const int stride_b, const int stride_h, const int stride_d,
-                             cudaStream_t stream) {
+                             const NVTETensor freqs, const NVTETensor start_positions,
+                             NVTETensor output, const NVTE_QKV_Format qkv_format,
+                             const bool interleaved, const int cp_size, const int cp_rank,
+                             const int s, const int b, const int h, const int d, const int d2,
+                             const int stride_s_or_t, const int stride_b, const int stride_h,
+                             const int stride_d, cudaStream_t stream) {
   NVTE_API_CALL(nvte_fused_rope_forward);
   using namespace transformer_engine;
-  fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
-                     *reinterpret_cast<const Tensor *>(cu_seqlens),
-                     *reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
-                     qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
-                     stride_b, stride_h, stride_d, stream);
+  fused_rope_forward(
+      *reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
+      *reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(start_positions),
+      reinterpret_cast<Tensor *>(output), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
+      stride_s_or_t, stride_b, stride_h, stride_d, stream);
 }

 void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
......
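
The kernel-side change above reduces to an index shift into the `freqs` table on the non-context-parallel path. Restated in Python for clarity (a sketch of the indexing only, with names mirroring the CUDA diff, not the actual kernel code):

```python
# fused_rope_forward_kernel, cp_size == 1 path: the thread handling sequence
# position s_id of batch element b_id reads freqs[s_id + begin_offset] rather
# than freqs[s_id]; a null start_positions pointer means begin_offset == 0.
def freqs_row(s_id: int, b_id: int, start_positions=None) -> int:
    begin_offset = 0 if start_positions is None else int(start_positions[b_id])
    return s_id + begin_offset
```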
@@ -20,6 +20,7 @@ extern "C" {
 *  \param[in]     cu_seqlens         The cumulative sum of sequence lengths tensor.
 *                                    (Required for the thd format, empty tensor for other formats)
 *  \param[in]     freqs              The freqs tensor.
+ *  \param[in]     start_positions    The beginning offsets for applying RoPE embeddings.
 *  \param[out]    output             Output tensor.
 *  \param[in]     qkv_format         QKV format.
 *  \param[in]     interleaved        Whether to use interleaved rotary position embedding.
@@ -37,12 +38,12 @@ extern "C" {
 *  \param[in]     stream             CUDA stream used for the operation.
 */
 void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
-                             const NVTETensor freqs, NVTETensor output,
-                             const NVTE_QKV_Format qkv_format, const bool interleaved,
-                             const int cp_size, const int cp_rank, const int s, const int b,
-                             const int h, const int d, const int d2, const int stride_s_or_t,
-                             const int stride_b, const int stride_h, const int stride_d,
-                             cudaStream_t stream);
+                             const NVTETensor freqs, const NVTETensor start_positions,
+                             NVTETensor output, const NVTE_QKV_Format qkv_format,
+                             const bool interleaved, const int cp_size, const int cp_rank,
+                             const int s, const int b, const int h, const int d, const int d2,
+                             const int stride_s_or_t, const int stride_b, const int stride_h,
+                             const int stride_d, cudaStream_t stream);

 /*! \brief Compute the backward of the fused rope.
 *
......
@@ -278,6 +278,7 @@ void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const
 **************************************************************************************************/

 at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
+                              const std::optional<at::Tensor> start_positions,
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                               const int cp_rank);
......
@@ -7,6 +7,7 @@
 #include "extensions.h"

 at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
+                              const std::optional<at::Tensor> start_positions,
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                               const int cp_rank) {
@@ -26,6 +27,11 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
   auto freqs_cu = makeTransformerEngineTensor(freqs);
   auto output_cu = makeTransformerEngineTensor(output);

+  auto start_positions_cu = transformer_engine::TensorWrapper();  // empty start_positions tensor
+  if (start_positions) {
+    start_positions_cu = makeTransformerEngineTensor(start_positions.value());
+  }
+
   if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
     TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
     TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
@@ -54,9 +60,9 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
     auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
     nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                            output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
-                            h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
-                            at::cuda::getCurrentCUDAStream());
+                            start_positions_cu.data(), output_cu.data(), qkv_format, interleaved,
+                            cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0,
+                            stride_h, stride_d, at::cuda::getCurrentCUDAStream());
     return output;
   }
@@ -87,9 +93,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
               "greater than the freqs tensor");
   auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
-  nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
-                          qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
-                          stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
+  nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                          start_positions_cu.data(), output_cu.data(), qkv_format, interleaved,
+                          cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
+                          at::cuda::getCurrentCUDAStream());
   return output;
 }
@@ -142,8 +149,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
   nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
                            input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
-                           max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
-                           at::cuda::getCurrentCUDAStream());
+                           max_s, b, h, d, d2, stride_t,
+                           /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
   return input_grads;
 }
......
@@ -119,6 +119,7 @@ class FusedRoPEFunc(torch.autograd.Function):
         ctx,
         t: torch.Tensor,
         freqs: torch.Tensor,
+        start_positions: Union[torch.Tensor, None] = None,
         tensor_format: str = "sbhd",
         interleaved: bool = False,
         cu_seqlens: Union[torch.Tensor, None] = None,
@@ -126,6 +127,7 @@ class FusedRoPEFunc(torch.autograd.Function):
         cp_rank: int = 0,
     ) -> torch.Tensor:
         """Fused RoPE forward."""
+
         if freqs.dtype != torch.float32:
             freqs = freqs.float()
         assert tensor_format in (
@@ -134,7 +136,14 @@ class FusedRoPEFunc(torch.autograd.Function):
             "thd",
         ), f"Unsupported tensor_format: {tensor_format}."
         output = tex.fused_rope_forward(
-            t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank
+            t,
+            freqs,
+            start_positions,
+            QKVFormat[tensor_format],
+            interleaved,
+            cu_seqlens,
+            cp_size,
+            cp_rank,
         )
         ctx.save_for_backward(freqs, cu_seqlens)
         ctx.tensor_format = tensor_format
@@ -158,7 +167,7 @@ class FusedRoPEFunc(torch.autograd.Function):
             ctx.cp_rank,
         )

-        return grad_input, None, None, None, None, None, None
+        return grad_input, None, None, None, None, None, None, None


 def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
@@ -185,6 +194,7 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
 def _apply_rotary_pos_emb_base(
     t: torch.Tensor,
     freqs: torch.Tensor,
+    start_positions: torch.Tensor = None,
     tensor_format: str = "sbhd",
     interleaved: bool = False,
 ) -> torch.Tensor:
@@ -199,6 +209,9 @@ def _apply_rotary_pos_emb_base(
     freqs: torch.Tensor
         Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
         with `s2 >= s` and `d2 <= d`.
+    start_positions: torch.Tensor, default = None.
+        Tokens of sequence `i` are rotated with a position offset of
+        `start_positions[i]`. If `start_positions=None`, no offset is applied.
     tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
         Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
         `[seq, bs, ...]`.
@@ -208,14 +221,32 @@ def _apply_rotary_pos_emb_base(
     max_seq_len = freqs.shape[0]
     cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

+    # In case `start_positions` are provided, create a staggered `freqs` tensor
+    # offset by the values in `start_positions`.
+    # `start_positions` is only supported for `cp_size=1` and inference.
+    if start_positions is not None:
+        max_offset = torch.max(start_positions)
+        assert (
+            max_offset + cur_seq_len <= max_seq_len
+        ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
+
+        # Stack staggered rope embeddings along the batch dimension
+        freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
+
+        # Note that from this point, `freqs` has a shape `(s,b,1,d)`.
+
     # Only apply the rotary embeddings up to the sequence length of the running
     # input.
     assert (
         cur_seq_len <= max_seq_len
     ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
     freqs = freqs[:cur_seq_len]
+
+    # [seq, 1, 1, dim] -> [1, seq, 1, dim] or
+    # [seq, b, 1, dim] -> [b, seq, 1, dim]
     if tensor_format == "bshd":
-        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
+        freqs = freqs.transpose(0, 1)
     # cos/sin first then dtype conversion for better precision
     cos_ = torch.cos(freqs).to(t.dtype)
     sin_ = torch.sin(freqs).to(t.dtype)
@@ -252,13 +283,14 @@ def _get_freqs_on_this_cp_rank(
         )

     # cp_size == 1
-    return freqs[:seqlen]
+    return freqs


 def apply_rotary_pos_emb(
     t: torch.Tensor,
     freqs: torch.Tensor,
     tensor_format: str = "sbhd",
+    start_positions: Union[torch.Tensor, None] = None,
     interleaved: bool = False,
     fused: bool = False,
     cu_seqlens: Union[torch.Tensor, None] = None,
@@ -268,6 +300,19 @@ def apply_rotary_pos_emb(
     """
     Apply rotary positional embedding tensor to the input tensor.

+    Support matrix:
+    Fused/Unfused:
+        Training:
+            qkv_formats:            "thd", "bshd", "sbhd"
+            context parallelism:    yes
+            start_positions:        no
+            interleaving:           yes
+        Inference:
+            qkv_formats:            "thd", "bshd", "sbhd"
+            context parallelism:    no
+            start_positions:        yes
+            interleaving:           yes
+
     Parameters
     ----------
     t: torch.Tensor
@@ -276,6 +321,9 @@ def apply_rotary_pos_emb(
     freqs: torch.Tensor
         Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
         with `s2 >= s` and `d2 <= d`.
+    start_positions: torch.Tensor, default = None.
+        Tokens of sequence `i` are rotated with a position offset of
+        `start_positions[i]`. If `start_positions=None`, no offset is applied.
     tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
         is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
         of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
@@ -292,27 +340,43 @@ def apply_rotary_pos_emb(
     cp_rank: int, default = 0.
         Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
     """

+    # `start_positions` is only supported for `cp_size=1` and inference.
+    assert not (
+        cp_size > 1 and start_positions is not None
+    ), """start_positions != None with CP SIZE > 1 is not supported!"""
+
     assert (
         tensor_format != "thd" or cu_seqlens is not None
     ), "cu_seqlens must not be None when tensor_format is 'thd'."

     if fused:
         return FusedRoPEFunc.apply(
-            t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
+            t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
         )

     # Unfused THD format
     if tensor_format == "thd":
         cu_seqlens = cu_seqlens // cp_size
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+
+        # The following code essentially splits the `thd` tensor into corresponding
+        # `s1hd` tensors (for each sequence) and applies rotary embedding to
+        # those sequences individually.
+        # Note that if `start_positions` is not `None`, then for each sequence,
+        # its corresponding RoPE offset is also supplied from `start_positions`
+        # individually.
         return torch.cat(
             [
                 _apply_rotary_pos_emb_base(
                     x.unsqueeze(1),
                     _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
+                    start_positions=(
+                        start_positions[idx : idx + 1] if start_positions is not None else None
+                    ),
                     interleaved=interleaved,
                 )
-                for x in torch.split(t, seqlens)
+                for idx, x in enumerate(torch.split(t, seqlens))
             ]
         ).squeeze(1)
@@ -326,6 +390,7 @@ def apply_rotary_pos_emb(
     return _apply_rotary_pos_emb_base(
         t,
         _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
+        start_positions,
         tensor_format,
         interleaved=interleaved,
    )
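
To make the unfused path concrete: when `start_positions` is given, `_apply_rotary_pos_emb_base` turns the `[s2, 1, 1, d2]` freqs table into a staggered `[s, b, 1, d2]` tensor in which batch `i` sees rows `start_positions[i] : start_positions[i] + s`. A self-contained sketch of just that slicing, with a hypothetical helper name:

```python
import torch

def stagger_freqs(freqs: torch.Tensor, start_positions: torch.Tensor, cur_seq_len: int):
    """Slice one window of cur_seq_len rows per sequence and stack along batch:
    [max_seq_len, 1, 1, dim] -> [cur_seq_len, batch, 1, dim]."""
    assert int(start_positions.max()) + cur_seq_len <= freqs.shape[0]
    return torch.cat([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)

# Three sequences with offsets 0, 2, and 5 into a length-16 table.
freqs = torch.arange(16.0).view(16, 1, 1, 1)
staggered = stagger_freqs(freqs, torch.tensor([0, 2, 5]), cur_seq_len=4)
print(staggered[:, 1, 0, 0])  # tensor([2., 3., 4., 5.]): sequence 1 starts at position 2
```

The cos/sin of the staggered table are then applied exactly as in the unstaggered case, which is why the new tests can compare the fused and unfused outputs directly.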