Unverified Commit 6c1a8bb5 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Common][PyTorch] Fused `apply_rotorary_pos_emb` (#517)



* fused apply rope
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarXin Yao <yaox12@outlook.com>

* resolve comments
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* make rotary_percent optional
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix ci
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix test
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* add rope test to qa
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix linting
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* sync apex: add transpose_output_memory
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* small fix
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* sync apex: fuse sin/cos
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* sync apex: fused rope for thd format
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Fix license headers
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* add support for bshd format
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* support different seq length
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* update
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* update copyright
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* remove transpose_output_memory
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Make outputs contiguous in SBHD case
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarXin Yao <yaox12@outlook.com>
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptredak@nvidia.com>
parent b957aa47
......@@ -12,5 +12,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
from typing import Callable, Dict, Tuple, Union
from transformer_engine.pytorch.attention import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
def apply_rotary_pos_emb_thd(
t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
for x in torch.split(t, seqlens)
]
).squeeze(1)
def get_tol(dtype: torch.dtype) -> Dict:
if dtype == torch.bfloat16:
return dict(atol=1e-2, rtol=1e-2)
elif dtype == torch.float16:
return dict(atol=1e-3, rtol=1e-3)
return dict(atol=1e-5, rtol=1.3e-6)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
t = torch.ones_like(output)
return torch.sum(output * t)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
dtype: torch.dtype,
seq_length: int,
hidden_size: int,
rotary_percent: float,
margin: int,
transpose: Union[Tuple, None],
tensor_format: str,
loss_func: Callable,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
(seq_length - margin, batch_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
if transpose:
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(seq_length)
# unfused
output_unfused = apply_rotary_pos_emb(
t, emb, tensor_format=tensor_format, fused=False
)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t,
emb,
tensor_format=tensor_format,
fused=True,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
assert output_fused.is_contiguous()
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope_thd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
transpose: Union[Tuple, None],
loss_func: Callable,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = torch.tensor(
[0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
dtype=torch.int32,
device=device,
)
t = torch.rand(
(cu_seqlens[-1], head_num, hidden_size),
dtype=dtype,
device=device,
)
if transpose:
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(cu_seqlens[-1])
# unfused
output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
......@@ -34,7 +34,8 @@ list(APPEND transformer_engine_SOURCES
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
fused_softmax/scaled_upper_triang_masked_softmax.cu
fused_rope/fused_rope.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/fused_rope.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
namespace transformer_engine {
template <typename scalar_t>
__device__ void fused_rope_block_forward(
const scalar_t *src, const float *freqs, scalar_t *dst,
const int offset_block, const int offset_block_dst, const int h,
const int d, const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x;
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? -src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] =
v_src * (scalar_t)v_cos + v_src_rotate * (scalar_t)v_sin;
}
}
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] =
src[offset_head + d_id * stride_d];
}
}
}
}
template <typename scalar_t>
__device__ void fused_rope_block_backward(
const scalar_t *src, const float *freqs, scalar_t *dst,
const int offset_block, const int offset_block_dst, const int h,
const int d, const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x;
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
scalar_t v_cos = cosf(freqs[s_id * d2 + d_id]);
scalar_t v_sin = (d_id + d2 / 2 < d2)
? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}
// handle the tail
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] =
src[offset_head + d_id * stride_d];
}
}
}
}
template <typename scalar_t>
__global__ void fused_rope_forward_kernel(
const scalar_t *src, const float *freqs, scalar_t *dst, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_backward_kernel(
const scalar_t *src, const float *freqs, scalar_t *dst, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_forward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs,
scalar_t *dst, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_backward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs,
scalar_t *dst, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const float *freqs,
scalar_t *output, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d,
o_stride_s, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads,
const float *freqs, scalar_t *input_grads,
const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h,
stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_thd_forward_launcher(
const scalar_t *input, const int *cu_seqlens, const float *freqs,
scalar_t *output, const int max_s, const int b, const int h, const int d,
const int d2, const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_thd_backward_launcher(
const scalar_t *output_grads, const int *cu_seqlens, const float *freqs,
scalar_t *input_grads, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_rope_forward(const Tensor &input, const Tensor &freqs,
Tensor *output, const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_forward_launcher(
reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs,
Tensor *input_grads, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(
reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens,
const Tensor &freqs, Tensor *output,
const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_thd_forward_launcher(
reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
stream););
}
void fused_rope_thd_backward(const Tensor &output_grads,
const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *input_grads, const int max_s, const int b,
const int h, const int d, const int d2,
const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_thd_backward_launcher(
reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), max_s, b, h, d,
d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
stream););
}
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
NVTETensor output, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads,
const NVTETensor freqs, NVTETensor input_grads,
const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_forward(const NVTETensor input,
const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_forward);
using namespace transformer_engine;
fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
o_stride_d, stream);
}
void nvte_fused_rope_thd_backward(
const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_backward);
using namespace transformer_engine;
fused_rope_thd_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), max_s, b, h,
d, d2, stride_t, stride_h, stride_d, o_stride_t,
o_stride_h, o_stride_d, stream);
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_
#define TRANSFORMER_ENGINE_FUSED_ROPE_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Apply rotary positional embedding to the input tensor.
*
* \param[in] input Input tensor for fused rope.
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of input.
* \param[in] stride_b Stride of the b dimension of input.
* \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_s Stride of the s dimension of output.
* \param[in] o_stride_b Stride of the b dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
NVTETensor output, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
/*! \brief Compute the backward of the fused rope.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient tensor to calculate.
* \param[in] s Length of the s dimension of output_grads.
* \param[in] b Length of the b dimension of output_grads.
* \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of output_grads.
* \param[in] stride_b Stride of the b dimension of output_grads.
* \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_s Stride of the s dimension of input_grads.
* \param[in] o_stride_b Stride of the b dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_backward(const NVTETensor output_grads,
const NVTETensor freqs, NVTETensor input_grads,
const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the input tensor in thd format.
*
* \param[in] input Input tensor for fused rope.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of input.
* \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_t Stride of the t dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_forward(const NVTETensor input,
const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient to calculate.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of output_grads.
* \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_t Stride of the t dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_backward(
const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_FUSED_ROPE_H_
......@@ -974,6 +974,7 @@ class RotaryPositionEmbedding(torch.nn.Module):
def __init__(
self,
dim: int,
rotary_percent: float = 1.0,
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
):
......@@ -982,6 +983,8 @@ class RotaryPositionEmbedding(torch.nn.Module):
----------
dim: int
rotary embedding dimension
rotary_percent: float
Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595
......@@ -989,8 +992,16 @@ class RotaryPositionEmbedding(torch.nn.Module):
pre-trained max_position_embeddings before position interpolation
"""
super().__init__()
if rotary_percent < 1.0:
dim = int(dim * rotary_percent)
self.seq_len_interpolation_factor = seq_len_interpolation_factor
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = 1.0 / (
10000
** (
torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
/ dim
)
)
self.register_buffer('inv_freq', inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
......@@ -1005,8 +1016,10 @@ class RotaryPositionEmbedding(torch.nn.Module):
offset: int, default = 0
fixed offset for freqencies
"""
seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
seq = seq.type_as(self.inv_freq)
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+ offset
)
if (self.pretrained_max_position_embeddings is not None
and self.seq_len_interpolation_factor is not None):
......@@ -1025,6 +1038,58 @@ class RotaryPositionEmbedding(torch.nn.Module):
# emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1))
class FusedRoPEFunc(torch.autograd.Function):
"""
Function for FusedRoPE
This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
"""
@staticmethod
def forward(
ctx,
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(
t.transpose(0, 1), freqs, True
).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
return output
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
change sign so the last dimension becomes [-odd, +even]
......@@ -1035,38 +1100,57 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd"
) -> torch.Tensor:
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
"""
Parameters
----------
t: torch.Tensor
input tensor on which rotary positional embedding will be applied
freqs: torch.Tensor
rotary positional embeding tensor `freqs` is of shape
`[seq_length, ..., dim]`
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`.
Apply rotary positional embedding tensor to the input tensor.
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[s, b, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
fused: bool, default = False
Whether to use a fused applying RoPE implementation.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
"""
assert tensor_format in ("sbhd", "bshd"),("Only formats `sbhd` or `bshd` "
"are supported for input tensor "
"`t`.")
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
f"when fused is False, got {tensor_format}."
)
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
if cur_seq_len > max_seq_len:
raise Exception(f"Rotary Embeddings only supported upto {max_seq_len} "
"sequence length!")
freqs = freqs[:cur_seq_len].to(t.dtype)
assert cur_seq_len <= max_seq_len, (
f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
)
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0,1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
......@@ -1074,7 +1158,7 @@ def apply_rotary_pos_emb(
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
t = (t * cos_) + (_rotate_half(t) * sin_)
return torch.cat((t, t_pass), dim=-1)
......
......@@ -35,6 +35,7 @@
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/rmsnorm.h>
......
......@@ -484,6 +484,9 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
transformer_engine::DType otype
);
/***************************************************************************************************
* Softmax
**************************************************************************************************/
at::Tensor scaled_softmax_forward(at::Tensor input,
float scale_factor
......@@ -518,6 +521,34 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
float scale_factor
);
/***************************************************************************************************
* Rotary positional embedding
**************************************************************************************************/
at::Tensor fused_rope_forward(const at::Tensor &input,
const at::Tensor &freqs,
const bool transpose_output_memory
);
at::Tensor fused_rope_backward(const at::Tensor &output_grads,
const at::Tensor &freqs,
const bool transpose_output_memory
);
at::Tensor fused_rope_thd_forward(const at::Tensor &input,
const at::Tensor &cu_seqlens,
const at::Tensor &freqs
);
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads,
const at::Tensor &cu_seqlens,
const at::Tensor &freqs
);
/***************************************************************************************************
* Misc
**************************************************************************************************/
size_t get_cublasLt_version();
size_t get_cudnn_version();
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory) {
using namespace transformer_engine;
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(3) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = input.size(0);
const int b = input.size(1);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = input.stride(0);
const int stride_b = input.stride(1);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
at::Tensor output;
if (transpose_output_memory) {
output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
output = torch::empty({s, b, h, d}, act_options);
}
// output strides
const int o_stride_s = output.stride(0);
const int o_stride_b = output.stride(1);
const int o_stride_h = output.stride(2);
const int o_stride_d = output.stride(3);
auto input_cu = makeTransformerEngineTensor(input);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s,
b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
o_stride_s, o_stride_b, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads,
const at::Tensor &freqs,
const bool transpose_output_memory) {
using namespace transformer_engine;
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(
output_grads.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = output_grads.size(0);
const int b = output_grads.size(1);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s = output_grads.stride(0);
const int stride_b = output_grads.stride(1);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
at::Tensor input_grads;
if (transpose_output_memory) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
}
const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
const int o_stride_d = input_grads.stride(3);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_backward(
output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
at::Tensor fused_rope_thd_forward(const at::Tensor &input,
const at::Tensor &cu_seqlens,
const at::Tensor &freqs) {
using namespace transformer_engine;
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(2) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = input.size(0);
const int h = input.size(1);
const int d = input.size(2);
// input strides
const int stride_t = input.stride(0);
const int stride_h = input.stride(1);
const int stride_d = input.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
auto output = torch::empty({t, h, d}, act_options);
// output strides
const int o_stride_t = output.stride(0);
const int o_stride_h = output.stride(1);
const int o_stride_d = output.stride(2);
auto input_cu = makeTransformerEngineTensor(input);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_thd_forward(
input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
o_stride_d, at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads,
const at::Tensor &cu_seqlens,
const at::Tensor &freqs) {
using namespace transformer_engine;
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = output_grads.size(0);
const int h = output_grads.size(1);
const int d = output_grads.size(2);
// output_grads strides
const int stride_t = output_grads.stride(0);
const int stride_h = output_grads.stride(1);
const int stride_d = output_grads.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
auto input_grads = torch::empty({t, h, d}, act_options);
const int o_stride_t = input_grads.stride(0);
const int o_stride_h = input_grads.stride(1);
const int o_stride_d = input_grads.stride(2);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_thd_backward(
output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
......@@ -75,6 +75,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
// fused apply rope
m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD");
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format");
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format");
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment