Unverified commit ba605f18, authored by Xin Yao, committed by GitHub.
Browse files

[PyTorch][Common] Refactor RoPE (#1626)



* refactor to add cp support for sbhd/bshd

Signed-off-by: Xin Yao <xiny@nvidia.com>

* support interleaved

Signed-off-by: Xin Yao <xiny@nvidia.com>

* format

Signed-off-by: Xin Yao <xiny@nvidia.com>

* add interleaved to RotaryPositionEmbedding in test

Signed-off-by: Xin Yao <xiny@nvidia.com>

* update

Signed-off-by: Xin Yao <xiny@nvidia.com>

* merge sbhd/bshd and thd functions

Signed-off-by: Xin Yao <xiny@nvidia.com>

---------

Signed-off-by: Xin Yao <xiny@nvidia.com>
parent ff884e20
......@@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import (
)
def _get_thd_freqs_on_this_cp_rank(
cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
if cp_size > 1:
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
else:
return freqs[: x.size(0)]
def apply_rotary_pos_emb_thd(
    t: torch.Tensor,
    cu_seqlens: torch.Tensor,
    freqs: torch.Tensor,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """A baseline implementation of applying RoPE for `thd` format.

    Args:
        t (Tensor): Input tensor T is of shape [t, h, d]
        cu_seqlens (Tensor): Cumulative sum of sequence lengths in a batch for `t`,
            with shape [b + 1] and dtype torch.int32.
        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
        cp_size (int): Context parallel world size.
        cp_rank (int): Context parallel rank.

    Returns:
        Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
    """
    # Each rank only holds 1/cp_size of every sequence, so scale the
    # cumulative offsets down before recovering the per-sequence lengths.
    local_cu_seqlens = cu_seqlens // cp_size
    lengths = torch.diff(local_cu_seqlens).tolist()
    rotated = []
    for seq in torch.split(t, lengths):
        seq_freqs = _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, seq, freqs)
        rotated.append(apply_rotary_pos_emb(seq.unsqueeze(1), seq_freqs))
    return torch.cat(rotated).squeeze(1)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2
......@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope(
dtype: torch.dtype,
seq_length: int,
......@@ -85,6 +41,8 @@ def test_fused_rope(
transpose: Union[Tuple, None],
tensor_format: str,
loss_func: Callable,
cp_size: int,
interleaved: bool,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
......@@ -99,14 +57,22 @@ def test_fused_rope(
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(seq_length)
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(seq_length * cp_size)
assert emb.is_contiguous()
for cp_rank in range(cp_size):
# unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
output_unfused = apply_rotary_pos_emb(
t.float(), emb, tensor_format=tensor_format, fused=False
t.float(),
emb,
tensor_format=tensor_format,
interleaved=interleaved,
fused=False,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
......@@ -118,7 +84,10 @@ def test_fused_rope(
t,
emb,
tensor_format=tensor_format,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
......@@ -135,7 +104,8 @@ def test_fused_rope(
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2, 3])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope_thd(
dtype: torch.dtype,
hidden_size: int,
......@@ -143,6 +113,7 @@ def test_fused_rope_thd(
transpose: Union[Tuple, None],
loss_func: Callable,
cp_size: int,
interleaved: bool,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
......@@ -170,15 +141,23 @@ def test_fused_rope_thd(
t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(cu_seqlens_padded[-1])
assert emb.is_contiguous()
for cp_rank in range(cp_size):
# unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
output_unfused = apply_rotary_pos_emb_thd(
t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
output_unfused = apply_rotary_pos_emb(
t.float(),
emb,
tensor_format="thd",
interleaved=interleaved,
fused=False,
cu_seqlens=cu_seqlens_padded,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
......@@ -189,6 +168,7 @@ def test_fused_rope_thd(
output_fused = apply_rotary_pos_emb(
t,
emb,
interleaved=interleaved,
fused=True,
tensor_format="thd",
cu_seqlens=cu_seqlens_padded,
......
......@@ -16,10 +16,11 @@ namespace transformer_engine {
template <typename scalar_t>
__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int s_id, const int offset_block,
const int offset_block_dst, const int h, const int d,
const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
const bool interleaved, const int s_id,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2, const int stride_h,
const int stride_d, const int o_stride_h,
const int o_stride_d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
......@@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2)
float v_src_rotate;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? -static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}
......@@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
template <typename scalar_t>
__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int s_id, const int offset_block,
const int offset_block_dst, const int h, const int d,
const int d2, const int stride_h, const int stride_d,
const bool interleaved, const int s_id,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2,
const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = cosf(freqs[s_id * d2 + d_id]);
float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
float v_sin;
if (!interleaved) {
v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
} else {
v_sin =
(d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]);
}
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d]
: src[offset_src + (d2 / 2 - d2) * stride_d];
float v_src_rotate;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}
......@@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
}
template <typename scalar_t>
__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
__global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const bool interleaved,
const int cp_size, const int cp_rank, const int s,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const int cp_size,
const int cp_rank, const int h, const int d,
const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
if (cu_seqlens != nullptr) { // THD
int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start;
if (t_id >= end) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
offset_block = t_id * stride_s_or_t;
offset_block_dst = t_id * o_stride_s_or_t;
cur_seqlens = end - start;
} else { // SBHD/BSHD
offset_block = s_id * stride_s_or_t + b_id * stride_b;
offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
cur_seqlens = s;
}
int s_id_for_freqs;
if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
......@@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu
} else {
s_id_for_freqs = s_id;
}
fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d);
fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const int cp_size,
const int cp_rank, const int h, const int d,
const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) {
__global__ void fused_rope_backward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst,
const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h,
const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h,
const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
if (cu_seqlens != nullptr) { // THD
int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start;
if (t_id >= end) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
offset_block = t_id * stride_s_or_t;
offset_block_dst = t_id * o_stride_s_or_t;
cur_seqlens = end - start;
} else { // SBHD/BSHD
offset_block = s_id * stride_s_or_t + b_id * stride_b;
offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
cur_seqlens = s;
}
int s_id_for_freqs;
if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
......@@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c
} else {
s_id_for_freqs = s_id;
}
fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d);
fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output,
void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
scalar_t *output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
o_stride_s_or_t = h * d;
o_stride_b = 0;
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
o_stride_s_or_t = b * h * d;
o_stride_b = h * d;
} else {
o_stride_s_or_t = h * d;
o_stride_b = s * h * d;
}
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d);
input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs,
scalar_t *input_grads, const int s, const int b, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d,
void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
const float *freqs, scalar_t *input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
o_stride_s_or_t = h * d;
o_stride_b = 0;
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
o_stride_s_or_t = b * h * d;
o_stride_b = h * d;
} else {
o_stride_s_or_t = h * d;
o_stride_b = s * h * d;
}
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d,
o_stride_s, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
const float *freqs, scalar_t *output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
const float *freqs, scalar_t *input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h,
stride_d, o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s,
const int b, const int h, const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b, const int h,
const int d, const int d2, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), s, b, h, d,
d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream););
}
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *output, const int cp_size, const int cp_rank, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_thd_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), cp_size,
cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d, stream););
reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream););
}
void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens,
const Tensor &freqs, Tensor *input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_thd_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr),
cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h,
stride_d, o_stride_t, o_stride_h, o_stride_d, stream););
reinterpret_cast<scalar_t *>(input_grads->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream););
}
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream);
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
NVTETensor input_grads, const int s, const int b, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d,
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), s, b, h, d, d2, stride_s, stride_b,
stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_forward);
using namespace transformer_engine;
fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), cp_size, cp_rank, max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_backward);
using namespace transformer_engine;
fused_rope_thd_backward(
*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(cu_seqlens), *reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
......@@ -7,6 +7,7 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_
#define TRANSFORMER_ENGINE_FUSED_ROPE_H_
#include "fused_attn.h"
#include "transformer_engine.h"
#ifdef __cplusplus
......@@ -16,112 +17,63 @@ extern "C" {
/*! \brief Apply rotary positional embedding to the input tensor.
*
* \param[in] input Input tensor for fused rope.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of input.
* \param[in] stride_b Stride of the b dimension of input.
* \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of input.
* \param[in] stride_b Stride of the b dimension of input. (0 for thd).
* \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_s Stride of the s dimension of output.
* \param[in] o_stride_b Stride of the b dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream);
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
/*! \brief Compute the backward of the fused rope.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient tensor to calculate.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of output_grads.
* \param[in] b Length of the b dimension of output_grads.
* \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of output_grads.
* \param[in] stride_b Stride of the b dimension of output_grads.
* \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of output_grads.
* \param[in] stride_b Stride of the b dimension of output_grads. (0 for thd).
* \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_s Stride of the s dimension of input_grads.
* \param[in] o_stride_b Stride of the b dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
NVTETensor input_grads, const int s, const int b, const int h,
const int d, const int d2, const int stride_s, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s,
const int o_stride_b, const int o_stride_h, const int o_stride_d,
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the input tensor in thd format.
 *
 * \note(review) This thd-specific entry point appears to be superseded by the
 *       qkv_format-aware nvte_fused_rope_forward — confirm against the current header.
 *
 * \param[in] input Input tensor for fused rope. (shape [t, h, d])
 * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. (shape [b + 1])
 * \param[in] freqs The freqs tensor. (shape [max_s, 1, 1, d2])
 * \param[out] output Output tensor. (same shape as input)
 * \param[in] cp_size Context parallel world size.
 * \param[in] cp_rank Context parallel rank.
 * \param[in] max_s Max sequence length.
 * \param[in] b Batch size.
 * \param[in] h Length of the h dimension of input.
 * \param[in] d Length of the d dimension of input.
 * \param[in] d2 Length of the d dimension of freqs.
 * \param[in] stride_t Stride of the t dimension of input.
 * \param[in] stride_h Stride of the h dimension of input.
 * \param[in] stride_d Stride of the d dimension of input.
 * \param[in] o_stride_t Stride of the t dimension of output.
 * \param[in] o_stride_h Stride of the h dimension of output.
 * \param[in] o_stride_d Stride of the d dimension of output.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
 *
 * \note(review) This thd-specific entry point appears to be superseded by the
 *       qkv_format-aware nvte_fused_rope_backward — confirm against the current header.
 *
 * \param[in] output_grads Incoming gradient tensor for backward. (shape [t, h, d])
 * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. (shape [b + 1])
 * \param[in] freqs The freqs tensor. (shape [max_s, 1, 1, d2])
 * \param[out] input_grads Input gradient to calculate. (same shape as output_grads)
 * \param[in] cp_size Context parallel world size.
 * \param[in] cp_rank Context parallel rank.
 * \param[in] max_s Max sequence length.
 * \param[in] b Batch size.
 * \param[in] h Length of the h dimension of output_grads.
 * \param[in] d Length of the d dimension of output_grads.
 * \param[in] d2 Length of the d dimension of freqs.
 * \param[in] stride_t Stride of the t dimension of output_grads.
 * \param[in] stride_h Stride of the h dimension of output_grads.
 * \param[in] stride_d Stride of the d dimension of output_grads.
 * \param[in] o_stride_t Stride of the t dimension of input_grads.
 * \param[in] o_stride_h Stride of the h dimension of input_grads.
 * \param[in] o_stride_d Stride of the d dimension of input_grads.
 * \param[in] stream CUDA stream used for the operation.
 */
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -265,16 +265,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
**************************************************************************************************/
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory);
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const bool transpose_output_memory);
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank);
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank);
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
/***************************************************************************************************
* Miscellaneous
......
......@@ -7,138 +7,38 @@
#include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory) {
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(3) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = input.size(0);
const int b = input.size(1);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = input.stride(0);
const int stride_b = input.stride(1);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
at::Tensor output;
if (transpose_output_memory) {
output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
output = torch::empty({s, b, h, d}, act_options);
}
// output strides
const int o_stride_s = output.stride(0);
const int o_stride_b = output.stride(1);
const int o_stride_h = output.stride(2);
const int o_stride_d = output.stride(3);
auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device());
auto output = at::empty(input.sizes(), act_options);
auto input_cu = makeTransformerEngineTensor(input);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const bool transpose_output_memory) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(output_grads.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = output_grads.size(0);
const int b = output_grads.size(1);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s = output_grads.stride(0);
const int stride_b = output_grads.stride(1);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
at::Tensor input_grads;
if (transpose_output_memory) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
}
const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
const int o_stride_d = input_grads.stride(3);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank) {
using namespace transformer_engine::pytorch;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
TORCH_CHECK(input.size(2) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = input.size(0);
// const int t = input.size(0);
const int h = input.size(1);
const int d = input.size(2);
// input strides
......@@ -146,51 +46,86 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_
const int stride_h = input.stride(1);
const int stride_d = input.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
const int b = cu_seqlens.value().size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
// output
auto act_options = input.options().requires_grad(false);
auto output = torch::empty({t, h, d}, act_options);
// output strides
const int o_stride_t = output.stride(0);
const int o_stride_h = output.stride(1);
const int o_stride_d = output.stride(2);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
auto input_cu = makeTransformerEngineTensor(input);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
return output;
}
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
// input sizes: (s, b, h, d) or (b, s, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1);
const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1);
const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(d >= d2,
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank) {
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
auto act_options =
at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device());
auto input_grads = at::empty(output_grads.sizes(), act_options);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (t, h, d)
// t: cumulative sum of sequence lengths
// h: head num
// d: dim of each head
const int t = output_grads.size(0);
// const int t = output_grads.size(0);
const int h = output_grads.size(1);
const int d = output_grads.size(2);
// output_grads strides
......@@ -198,25 +133,54 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
const int stride_h = output_grads.stride(1);
const int stride_d = output_grads.stride(2);
// batch size
const int b = cu_seqlens.size(0) - 1;
const int b = cu_seqlens.value().size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0);
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
auto input_grads = torch::empty({t, h, d}, act_options);
const int o_stride_t = input_grads.stride(0);
const int o_stride_h = input_grads.stride(1);
const int o_stride_d = input_grads.stride(2);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
return input_grads;
}
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1);
const int b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1);
const int stride_b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(d >= d2,
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
h, d, d2, stride_s, stride_b, stride_h, stride_d,
at::cuda::getCurrentCUDAStream());
return input_grads;
......
......@@ -229,10 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
py::call_guard<py::gil_scoped_release>());
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
......
......@@ -7,7 +7,12 @@ Rotary Position Embedding implementation of different types along with helper fu
"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module):
......@@ -22,19 +27,24 @@ class RotaryPositionEmbedding(torch.nn.Module):
seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None,
rotary_base: float = 10000.0,
interleaved: bool = False,
):
"""
Parameters
----------
dim: int
rotary embedding dimension
rotary_percent: float
Rotary embedding dimension.
rotary_percent: float, default = 1.0
Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int
if not None, discrete positions will be interpolated by this factor via the trick in
seq_len_interpolation_factor: int, default = None
If not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595
pretrained_max_position_embeddings: int
pre-trained max_position_embeddings before position interpolation
pretrained_max_position_embeddings: int, default = None
Pre-trained max_position_embeddings before position interpolation.
rotary_base: float, default = 10000.0
Base of the rotary position embedding.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
"""
super().__init__()
if rotary_percent < 1.0:
......@@ -50,17 +60,18 @@ class RotaryPositionEmbedding(torch.nn.Module):
)
self.register_buffer("inv_freq", inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
self.interleaved = interleaved
def forward(self, max_seq_len: int, offset: int = 0):
"""
Create rotary position embedding frequencies
Create rotary position embedding frequencies.
Parameters
----------
max_seq_len: int
sequence length of a sample
Sequence length of a sample.
offset: int, default = 0
fixed offset for freqencies
Fixed offset for frequencies.
"""
seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
......@@ -84,7 +95,12 @@ class RotaryPositionEmbedding(torch.nn.Module):
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
# first part even vector components, second part odd vector components,
# 2 * dim in dimension size
if not self.interleaved:
emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
freqs.shape[0], -1
)
# emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1))
......@@ -104,61 +120,146 @@ class FusedRoPEFunc(torch.autograd.Function):
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
interleaved: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
"""Fused RoPE forward."""
if freqs.dtype != torch.float32:
freqs = freqs.float()
if tensor_format == "sbhd":
output = tex.fused_rope_forward(t, freqs, False)
elif tensor_format == "bshd":
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
assert tensor_format in (
"sbhd",
"bshd",
"thd",
), f"Unsupported tensor_format: {tensor_format}."
output = tex.fused_rope_forward(
t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank
)
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
ctx.interleaved = interleaved
return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring
"""Fused RoPE backward."""
freqs, cu_seqlens = ctx.saved_tensors
if ctx.tensor_format == "sbhd":
grad_input = tex.fused_rope_backward(grad_output, freqs, False)
elif ctx.tensor_format == "bshd":
grad_input = tex.fused_rope_backward(
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(
grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
grad_output,
freqs,
QKVFormat[ctx.tensor_format],
ctx.interleaved,
cu_seqlens,
ctx.cp_size,
ctx.cp_rank,
)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None, None
return grad_input, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
change sign so the last dimension becomes [-odd, +even]
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
Args:
x: torch.Tensor. Input tensor.
interleaved: bool. Whether to use interleaved rotary position embedding.
Returns:
Tensor: Tensor rotated half.
"""
x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
x1, x2 = x.unbind(dim=-2)
if not interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
# interleaved
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
def _apply_rotary_pos_emb_base(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
) -> torch.Tensor:
    """
    Base implementation of applying rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
        embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    """
    seq_dim = 1 if tensor_format == "bshd" else 0
    cur_seq_len = t.shape[seq_dim]
    max_seq_len = freqs.shape[0]
    # The frequency table must cover at least the running sequence length.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
    freqs = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]

    # Compute cos/sin in float32 first, then cast to t's dtype, for better precision.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    # Only the first `rot_dim` features are rotated; the remainder (if any)
    # passes through untouched.
    rot_dim = freqs.shape[-1]
    t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:]

    # Cosine term plus the sign-flipped ("rotated half") sine term.
    t_rot = (t_rot * cos_) + (_rotate_half(t_rot, interleaved) * sin_)
    return torch.cat((t_rot, t_pass), dim=-1)
def _get_freqs_on_this_cp_rank(
freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
"""Get the position embedding on the current context parallel rank.
Args:
freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank.
"""
if cp_size > 1:
cp_seg = seqlen // 2
full_seqlen = cp_size * seqlen
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
# cp_size == 1
return freqs[:seqlen]
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
interleaved: bool = False,
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
......@@ -175,11 +276,13 @@ def apply_rotary_pos_emb(
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
fused: bool, default = False
Whether to use a fused applying RoPE implementation.
tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
fused: bool, default = False
Whether to use a fused applying RoPE implementation.
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
......@@ -189,37 +292,40 @@ def apply_rotary_pos_emb(
cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
f"when fused is False, got {tensor_format}."
if fused:
return FusedRoPEFunc.apply(
t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
)
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# Unfused THD format
if tensor_format == "thd":
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
_apply_rotary_pos_emb_base(
x.unsqueeze(1),
_get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
interleaved=interleaved,
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * cos_) + (_rotate_half(t) * sin_)
return torch.cat((t, t_pass), dim=-1)
# Unfused SBHD/BSHD format
if tensor_format == "sbhd":
seqlen = t.size(0)
elif tensor_format == "bshd":
seqlen = t.size(1)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
return _apply_rotary_pos_emb_base(
t,
_get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
tensor_format,
interleaved=interleaved,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment