Unverified Commit 6c1a8bb5 authored by Xin Yao, committed by GitHub

[Common][PyTorch] Fused `apply_rotary_pos_emb` (#517)



* fused apply rope
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Xin Yao <yaox12@outlook.com>

* resolve comments
Signed-off-by: Xin Yao <xiny@nvidia.com>

* make rotary_percent optional
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix ci
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix test
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add rope test to qa
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix linting
Signed-off-by: Xin Yao <xiny@nvidia.com>

* sync apex: add transpose_output_memory
Signed-off-by: Xin Yao <xiny@nvidia.com>

* small fix
Signed-off-by: Xin Yao <xiny@nvidia.com>

* sync apex: fuse sin/cos
Signed-off-by: Xin Yao <xiny@nvidia.com>

* sync apex: fused rope for thd format
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Fix license headers
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* add support for bshd format
Signed-off-by: Xin Yao <xiny@nvidia.com>

* support different seq length
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update copyright
Signed-off-by: Xin Yao <xiny@nvidia.com>

* remove transpose_output_memory
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Make outputs contiguous in SBHD case
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Xin Yao <yaox12@outlook.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
parent b957aa47
@@ -12,5 +12,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
from typing import Callable, Dict, Tuple, Union
from transformer_engine.pytorch.attention import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
def apply_rotary_pos_emb_thd(
t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
for x in torch.split(t, seqlens)
]
).squeeze(1)
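For intuition about the `thd` format: `cu_seqlens` packs variable-length sequences along the first dimension of a single `[t, h, d]` tensor, and the split above recovers the per-sequence views. A minimal sketch with hypothetical lengths:

# Three sequences of lengths 3, 5 and 2 packed into t = 10 tokens.
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()  # [3, 5, 2]
# torch.split(t, seqlens) yields views of shape [3, h, d], [5, h, d] and [2, h, d];
# RoPE is applied to each with positions restarting from 0.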
def get_tol(dtype: torch.dtype) -> Dict:
if dtype == torch.bfloat16:
return dict(atol=1e-2, rtol=1e-2)
elif dtype == torch.float16:
return dict(atol=1e-3, rtol=1e-3)
return dict(atol=1e-5, rtol=1.3e-6)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
t = torch.ones_like(output)
return torch.sum(output * t)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2048, 4096])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope(
dtype: torch.dtype,
seq_length: int,
hidden_size: int,
rotary_percent: float,
margin: int,
transpose: Union[Tuple, None],
tensor_format: str,
loss_func: Callable,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
(seq_length - margin, batch_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
if transpose:
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(seq_length)
# unfused
output_unfused = apply_rotary_pos_emb(
t, emb, tensor_format=tensor_format, fused=False
)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t,
emb,
tensor_format=tensor_format,
fused=True,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
assert output_fused.is_contiguous()
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
def test_fused_rope_thd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
transpose: Union[Tuple, None],
loss_func: Callable,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = torch.tensor(
[0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
dtype=torch.int32,
device=device,
)
t = torch.rand(
(cu_seqlens[-1], head_num, hidden_size),
dtype=dtype,
device=device,
)
if transpose:
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(cu_seqlens[-1])
# unfused
output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
@@ -34,7 +34,8 @@ list(APPEND transformer_engine_SOURCES
     fused_softmax/scaled_masked_softmax.cu
     fused_softmax/scaled_upper_triang_masked_softmax.cu
     fused_softmax/scaled_masked_softmax.cu
     fused_softmax/scaled_upper_triang_masked_softmax.cu
     fused_rope/fused_rope.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
                           "${CMAKE_CURRENT_SOURCE_DIR}/include")
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/fused_rope.h>
#include "../common.h"
#include "../util/logging.h"
#include "../utils.cuh"
namespace transformer_engine {
template <typename scalar_t>
__device__ void fused_rope_block_forward(
const scalar_t *src, const float *freqs, scalar_t *dst,
const int offset_block, const int offset_block_dst, const int h,
const int d, const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x;
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? -src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] =
v_src * (scalar_t)v_cos + v_src_rotate * (scalar_t)v_sin;
}
}
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] =
src[offset_head + d_id * stride_d];
}
}
}
}
template <typename scalar_t>
__device__ void fused_rope_block_backward(
const scalar_t *src, const float *freqs, scalar_t *dst,
const int offset_block, const int offset_block_dst, const int h,
const int d, const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x;
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
scalar_t v_cos = cosf(freqs[s_id * d2 + d_id]);
scalar_t v_sin = (d_id + d2 / 2 < d2)
? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
scalar_t v_src = src[offset_src];
scalar_t v_src_rotate = (d_id + d2 / 2 < d2)
? src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}
// handle the tail
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] =
src[offset_head + d_id * stride_d];
}
}
}
}
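For reference, the two block functions above compute the same rotation as the unfused PyTorch path added later in this commit; a minimal sketch of the per-position math on tensors whose last dimension is d, with d2 rotated channels (the helper names below are illustrative, not part of the commit):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # [x1, x2] -> [-x2, x1] along the last dimension; this is the
    # `d_id + d2 / 2 < d2` branch in fused_rope_block_forward.
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_forward_reference(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Rotate the first d2 channels and copy the remaining d - d2 channels
    # through unchanged (the "copy the rest" loop above).
    d2 = freqs.size(-1)
    x_rot, x_pass = x[..., :d2], x[..., d2:]
    y_rot = x_rot * freqs.cos() + rotate_half(x_rot) * freqs.sin()
    return torch.cat((y_rot, x_pass), dim=-1)

def rope_backward_reference(dy: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Adjoint of the forward rotation. The shifted sin index in
    # fused_rope_block_backward routes each gradient through the transpose of
    # rotate_half: the first half picks up +sin terms from the second half of
    # the incoming gradient, the second half picks up -sin terms from the first.
    d2 = freqs.size(-1)
    dy_rot, dy_pass = dy[..., :d2], dy[..., d2:]
    g1, g2 = torch.chunk(dy_rot, 2, dim=-1)
    f1, f2 = torch.chunk(freqs, 2, dim=-1)
    dx1 = g1 * f1.cos() + g2 * f2.sin()
    dx2 = g2 * f2.cos() - g1 * f1.sin()
    return torch.cat((dx1, dx2, dy_pass), dim=-1)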
template <typename scalar_t>
__global__ void fused_rope_forward_kernel(
const scalar_t *src, const float *freqs, scalar_t *dst, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_backward_kernel(
const scalar_t *src, const float *freqs, scalar_t *dst, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_forward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs,
scalar_t *dst, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_backward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs,
scalar_t *dst, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h,
d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const float *freqs,
scalar_t *output, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d,
o_stride_s, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads,
const float *freqs, scalar_t *input_grads,
const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h,
stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_thd_forward_launcher(
const scalar_t *input, const int *cu_seqlens, const float *freqs,
scalar_t *output, const int max_s, const int b, const int h, const int d,
const int d2, const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_thd_backward_launcher(
const scalar_t *output_grads, const int *cu_seqlens, const float *freqs,
scalar_t *input_grads, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_rope_forward(const Tensor &input, const Tensor &freqs,
Tensor *output, const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_forward_launcher(
reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs,
Tensor *input_grads, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(
reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens,
const Tensor &freqs, Tensor *output,
const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_thd_forward_launcher(
reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
stream););
}
void fused_rope_thd_backward(const Tensor &output_grads,
const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *input_grads, const int max_s, const int b,
const int h, const int d, const int d2,
const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_thd_backward_launcher(
reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), max_s, b, h, d,
d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
stream););
}
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
NVTETensor output, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads,
const NVTETensor freqs, NVTETensor input_grads,
const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_forward(const NVTETensor input,
const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_forward);
using namespace transformer_engine;
fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
o_stride_d, stream);
}
void nvte_fused_rope_thd_backward(
const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_backward);
using namespace transformer_engine;
fused_rope_thd_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), max_s, b, h,
d, d2, stride_t, stride_h, stride_d, o_stride_t,
o_stride_h, o_stride_d, stream);
}
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_
#define TRANSFORMER_ENGINE_FUSED_ROPE_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Apply rotary positional embedding to the input tensor.
*
* \param[in] input Input tensor for fused rope.
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of input.
* \param[in] stride_b Stride of the b dimension of input.
* \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_s Stride of the s dimension of output.
* \param[in] o_stride_b Stride of the b dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs,
NVTETensor output, const int s, const int b,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
/*! \brief Compute the backward of the fused rope.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient tensor to calculate.
* \param[in] s Length of the s dimension of output_grads.
* \param[in] b Length of the b dimension of output_grads.
* \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of output_grads.
* \param[in] stride_b Stride of the b dimension of output_grads.
* \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_s Stride of the s dimension of input_grads.
* \param[in] o_stride_b Stride of the b dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_backward(const NVTETensor output_grads,
const NVTETensor freqs, NVTETensor input_grads,
const int s, const int b, const int h,
const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the input tensor in thd format.
*
* \param[in] input Input tensor for fused rope.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of input.
* \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_t Stride of the t dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_forward(const NVTETensor input,
const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient to calculate.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of output_grads.
* \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_t Stride of the t dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_backward(
const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_FUSED_ROPE_H_
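For orientation, the s/b/h/d/d2 and stride arguments of these entry points map directly onto a PyTorch tensor's sizes and element strides, which is how the PyTorch wrapper later in this commit fills them in; a small sketch with hypothetical tensors:

import torch

input = torch.empty(4096, 2, 64, 128)  # hypothetical [s, b, h, d] tensor
freqs = torch.empty(4096, 1, 1, 128)   # hypothetical [s2, 1, 1, d2] tensor, d2 <= d
s, b, h, d = input.shape
d2 = freqs.size(-1)
stride_s, stride_b, stride_h, stride_d = input.stride()  # element strides, any memory layout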
@@ -974,6 +974,7 @@ class RotaryPositionEmbedding(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        rotary_percent: float = 1.0,
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
    ):
...@@ -982,6 +983,8 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -982,6 +983,8 @@ class RotaryPositionEmbedding(torch.nn.Module):
---------- ----------
dim: int dim: int
rotary embedding dimension rotary embedding dimension
rotary_percent: float
Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in if not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595 https://arxiv.org/abs/2306.15595
@@ -989,8 +992,16 @@
            pre-trained max_position_embeddings before position interpolation
        """
        super().__init__()
        if rotary_percent < 1.0:
            dim = int(dim * rotary_percent)
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
        inv_freq = 1.0 / (
            10000
            ** (
                torch.arange(0, dim, 2, dtype=torch.float32, device=torch.cuda.current_device())
                / dim
            )
        )
        self.register_buffer('inv_freq', inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
@@ -1005,8 +1016,10 @@
        offset: int, default = 0
            fixed offset for frequencies
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
            + offset
        )

        if (self.pretrained_max_position_embeddings is not None
                and self.seq_len_interpolation_factor is not None):
@@ -1025,6 +1038,58 @@
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))
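As an illustration of the new rotary_percent argument (values here are hypothetical), only a rotary_percent fraction of the head dimension receives rotary frequencies, which is what produces the d2 < d case handled by the fused kernels:

# With head dim 256 and rotary_percent = 0.5, the embedding is built for
# dim = int(256 * 0.5) = 128 channels, so the returned emb has last dim d2 = 128
# and the remaining 128 channels of each head are passed through unrotated.
rotary_pos_emb = RotaryPositionEmbedding(256, rotary_percent=0.5)
emb = rotary_pos_emb(2048)  # shape [2048, 1, 1, 128]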
class FusedRoPEFunc(torch.autograd.Function):
"""
Function for FusedRoPE
This implementation assumes the input tensor to be in `sbhd`, `bshd` or `thd` format and
the RoPE tensor to be of shape (s, 1, 1, d). It accepts arbitrary memory layouts to avoid
the expensive `.contiguous()` calls, thus it may not achieve the best memory access pattern.
"""
@staticmethod
def forward(
ctx,
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(
t.transpose(0, 1), freqs, True
).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
return output
@staticmethod
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    change sign so the last dimension becomes [-odd, +even]
@@ -1037,36 +1102,55 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    fused: bool, default = False
        Whether to use a fused implementation of applying RoPE.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
    """
    if fused:
        assert (
            tensor_format != "thd" or cu_seqlens is not None
        ), "cu_seqlens must not be None when tensor_format is 'thd'."
        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)

    assert tensor_format in ("sbhd", "bshd"), (
        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
        f"when fused is False, got {tensor_format}."
    )

    max_seq_len = freqs.shape[0]
    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]

    # Only apply the rotary embeddings up to the sequence length of the running
    # input.
    assert cur_seq_len <= max_seq_len, (
        f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    )
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
    # cos/sin first then dtype conversion for better precision
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    rot_dim = freqs.shape[-1]
    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
@@ -1074,7 +1158,7 @@ def apply_rotary_pos_emb(
    # first part is cosine component
    # second part is sine component, need to change signs with _rotate_half method
    t = (t * cos_) + (_rotate_half(t) * sin_)
    return torch.cat((t, t_pass), dim=-1)
......
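Putting the pieces together, a minimal usage sketch of the fused path added by this commit (shapes and values are illustrative):

import torch
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

rotary_pos_emb = RotaryPositionEmbedding(128)
emb = rotary_pos_emb(4096)  # [4096, 1, 1, 128], float32

# sbhd layout: [s, b, h, d]
q = torch.randn(4096, 2, 64, 128, dtype=torch.bfloat16, device="cuda")
q_rot = apply_rotary_pos_emb(q, emb, tensor_format="sbhd", fused=True)

# thd layout: variable-length sequences packed along the token dimension
cu_seqlens = torch.tensor([0, 400, 1450, 2048], dtype=torch.int32, device="cuda")
t = torch.randn(2048, 64, 128, dtype=torch.bfloat16, device="cuda")
t_rot = apply_rotary_pos_emb(t, emb, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens)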
@@ -35,6 +35,7 @@
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/rmsnorm.h>
......
@@ -484,6 +484,9 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
                         transformer_engine::DType otype
);

/***************************************************************************************************
 * Softmax
 **************************************************************************************************/

at::Tensor scaled_softmax_forward(at::Tensor input,
                                  float scale_factor
@@ -518,6 +521,34 @@ at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                        float scale_factor
);

/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

at::Tensor fused_rope_forward(const at::Tensor &input,
                              const at::Tensor &freqs,
                              const bool transpose_output_memory
);

at::Tensor fused_rope_backward(const at::Tensor &output_grads,
                               const at::Tensor &freqs,
                               const bool transpose_output_memory
);

at::Tensor fused_rope_thd_forward(const at::Tensor &input,
                                  const at::Tensor &cu_seqlens,
                                  const at::Tensor &freqs
);

at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads,
                                   const at::Tensor &cu_seqlens,
                                   const at::Tensor &freqs
);

/***************************************************************************************************
 * Misc
 **************************************************************************************************/

size_t get_cublasLt_version();
size_t get_cudnn_version();
......
/*************************************************************************
* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory) {
using namespace transformer_engine;
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(3) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = input.size(0);
const int b = input.size(1);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = input.stride(0);
const int stride_b = input.stride(1);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
at::Tensor output;
if (transpose_output_memory) {
output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
output = torch::empty({s, b, h, d}, act_options);
}
// output strides
const int o_stride_s = output.stride(0);
const int o_stride_b = output.stride(1);
const int o_stride_h = output.stride(2);
const int o_stride_d = output.stride(3);
auto input_cu = makeTransformerEngineTensor(input);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s,
b, h, d, d2, stride_s, stride_b, stride_h, stride_d,
o_stride_s, o_stride_b, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads,
const at::Tensor &freqs,
const bool transpose_output_memory) {
using namespace transformer_engine;
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(
output_grads.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = output_grads.size(0);
const int b = output_grads.size(1);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s = output_grads.stride(0);
const int stride_b = output_grads.stride(1);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
at::Tensor input_grads;
if (transpose_output_memory) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
}
const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
const int o_stride_d = input_grads.stride(3);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_backward(
output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
at::Tensor fused_rope_thd_forward(const at::Tensor &input,
const at::Tensor &cu_seqlens,
const at::Tensor &freqs) {
using namespace transformer_engine;
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(2) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = input.size(0);
const int h = input.size(1);
const int d = input.size(2);
// input strides
const int stride_t = input.stride(0);
const int stride_h = input.stride(1);
const int stride_d = input.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
auto output = torch::empty({t, h, d}, act_options);
// output strides
const int o_stride_t = output.stride(0);
const int o_stride_h = output.stride(1);
const int o_stride_d = output.stride(2);
auto input_cu = makeTransformerEngineTensor(input);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_thd_forward(
input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
max_s, b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
o_stride_d, at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads,
const at::Tensor &cu_seqlens,
const at::Tensor &freqs) {
using namespace transformer_engine;
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = output_grads.size(0);
const int h = output_grads.size(1);
const int d = output_grads.size(2);
// output_grads strides
const int stride_t = output_grads.stride(0);
const int stride_h = output_grads.stride(1);
const int stride_d = output_grads.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
auto input_grads = torch::empty({t, h, d}, act_options);
const int o_stride_t = input_grads.stride(0);
const int o_stride_h = input_grads.stride(1);
const int o_stride_d = input_grads.stride(2);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_thd_backward(
output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
@@ -75,6 +75,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
  m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");

  // fused apply rope
  m.def("fused_rope_forward", &fused_rope_forward, "Fused Apply RoPE FWD");
  m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD");
  m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format");
  m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format");

  // Misc
  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
  m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
......