"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "e38074b1e6ad0975acbfa15d858c4bd7cd005e99"
Unverified commit ba605f18 authored by Xin Yao, committed by GitHub

[PyTorch][Common] Refactor RoPE (#1626)



* refactor to add cp support for sbhd/bshd
Signed-off-by: Xin Yao <xiny@nvidia.com>

* support interleaved
Signed-off-by: Xin Yao <xiny@nvidia.com>

* format
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add interleaved to RotaryPositionEmbedding in test
Signed-off-by: Xin Yao <xiny@nvidia.com>

* update
Signed-off-by: Xin Yao <xiny@nvidia.com>

* merge sbhd/bshd and thd functions
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
parent ff884e20
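For orientation, a minimal sketch of how the merged entry point is called after this change. Shapes, dtypes, and the choice of `interleaved=True` are illustrative (not taken from the diff), and the fused path requires a CUDA build of Transformer Engine:

```python
import torch
from transformer_engine.pytorch.dot_product_attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

seq_len, batch, heads, dim = 128, 2, 16, 64
cp_size, cp_rank = 2, 0  # context-parallel world size and this rank's index

t = torch.randn(seq_len, batch, heads, dim, device="cuda")
# freqs must cover the full (un-split) sequence: seq_len * cp_size positions
rope = RotaryPositionEmbedding(dim, interleaved=True)
emb = rope(seq_len * cp_size).cuda()

out = apply_rotary_pos_emb(
    t,
    emb,
    tensor_format="sbhd",
    interleaved=True,
    fused=True,
    cp_size=cp_size,
    cp_rank=cp_rank,
)
print(out.shape)  # same shape as t
```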
@@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import (
)
def _get_thd_freqs_on_this_cp_rank(
cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
if cp_size > 1:
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
else:
return freqs[: x.size(0)]
def apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb(
x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
    return output.sum() * 2
@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
+@pytest.mark.parametrize("cp_size", [1, 2])
+@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope(
    dtype: torch.dtype,
    seq_length: int,
@@ -85,6 +41,8 @@ def test_fused_rope(
    transpose: Union[Tuple, None],
    tensor_format: str,
    loss_func: Callable,
+    cp_size: int,
+    interleaved: bool,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
@@ -99,14 +57,22 @@ def test_fused_rope(
    t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True
-    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
-    emb = rotary_pos_emb(seq_length)
+    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
+    emb = rotary_pos_emb(seq_length * cp_size)
+    assert emb.is_contiguous()
+    for cp_rank in range(cp_size):
        # unfused
        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
        # for more accurate comparison
        output_unfused = apply_rotary_pos_emb(
-            t.float(), emb, tensor_format=tensor_format, fused=False
+            t.float(),
+            emb,
+            tensor_format=tensor_format,
+            interleaved=interleaved,
+            fused=False,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused = loss_func(output_unfused)
        loss_unfused.backward()
@@ -118,7 +84,10 @@ def test_fused_rope(
            t,
            emb,
            tensor_format=tensor_format,
+            interleaved=interleaved,
            fused=True,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
        )
        loss_fused = loss_func(output_fused)
        loss_fused.backward()
@@ -135,7 +104,8 @@ def test_fused_rope(
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
-@pytest.mark.parametrize("cp_size", [1, 2, 3])
+@pytest.mark.parametrize("cp_size", [1, 2])
+@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope_thd(
    dtype: torch.dtype,
    hidden_size: int,
@@ -143,6 +113,7 @@ def test_fused_rope_thd(
    transpose: Union[Tuple, None],
    loss_func: Callable,
    cp_size: int,
+    interleaved: bool,
) -> None:
    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
@@ -170,15 +141,23 @@ def test_fused_rope_thd(
    t = t.transpose(*transpose).contiguous().transpose(*transpose)
    t.requires_grad = True
-    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
+    rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb = rotary_pos_emb(cu_seqlens_padded[-1])
+    assert emb.is_contiguous()
    for cp_rank in range(cp_size):
        # unfused
        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
        # for more accurate comparison
-        output_unfused = apply_rotary_pos_emb_thd(
-            t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
+        output_unfused = apply_rotary_pos_emb(
+            t.float(),
+            emb,
+            tensor_format="thd",
+            interleaved=interleaved,
+            fused=False,
+            cu_seqlens=cu_seqlens_padded,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
        ).to(dtype)
        loss_unfused = loss_func(output_unfused)
        loss_unfused.backward()
@@ -189,6 +168,7 @@ def test_fused_rope_thd(
        output_fused = apply_rotary_pos_emb(
            t,
            emb,
+            interleaved=interleaved,
            fused=True,
            tensor_format="thd",
            cu_seqlens=cu_seqlens_padded,
......
@@ -7,6 +7,7 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_
#define TRANSFORMER_ENGINE_FUSED_ROPE_H_
+#include "fused_attn.h"
#include "transformer_engine.h"
#ifdef __cplusplus
@@ -16,112 +17,63 @@ extern "C" {
/*! \brief Apply rotary positional embedding to the input tensor.
 *
 * \param[in] input Input tensor for fused rope.
+ * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
+ *                       (Required for the thd format, empty tensor for other formats)
 * \param[in] freqs The freqs tensor.
 * \param[out] output Output tensor.
+ * \param[in] qkv_format QKV format.
+ * \param[in] interleaved Whether to use interleaved rotary position embedding.
+ * \param[in] cp_size Context parallel world size.
+ * \param[in] cp_rank Context parallel rank.
 * \param[in] s Length of the s dimension of input.
 * \param[in] b Length of the b dimension of input.
 * \param[in] h Length of the h dimension of input.
 * \param[in] d Length of the d dimension of input.
 * \param[in] d2 Length of the d dimension of freqs.
- * \param[in] stride_s Stride of the s dimension of input.
+ * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of input.
- * \param[in] stride_b Stride of the b dimension of input.
+ * \param[in] stride_b Stride of the b dimension of input. (0 for thd)
 * \param[in] stride_h Stride of the h dimension of input.
 * \param[in] stride_d Stride of the d dimension of input.
- * \param[in] o_stride_s Stride of the s dimension of output.
- * \param[in] o_stride_b Stride of the b dimension of output.
- * \param[in] o_stride_h Stride of the h dimension of output.
- * \param[in] o_stride_d Stride of the d dimension of output.
 * \param[in] stream CUDA stream used for the operation.
 */
-void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output,
-                             const int s, const int b, const int h, const int d, const int d2,
-                             const int stride_s, const int stride_b, const int stride_h,
-                             const int stride_d, const int o_stride_s, const int o_stride_b,
-                             const int o_stride_h, const int o_stride_d, cudaStream_t stream);
+void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
+                             const NVTETensor freqs, NVTETensor output,
+                             const NVTE_QKV_Format qkv_format, const bool interleaved,
+                             const int cp_size, const int cp_rank, const int s, const int b,
+                             const int h, const int d, const int d2, const int stride_s_or_t,
+                             const int stride_b, const int stride_h, const int stride_d,
+                             cudaStream_t stream);
/*! \brief Compute the backward of the fused rope.
 *
 * \param[in] output_grads Incoming gradient tensor for backward.
+ * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
+ *                       (Required for the thd format, empty tensor for other formats)
 * \param[in] freqs The freqs tensor.
 * \param[out] input_grads Input gradient tensor to calculate.
+ * \param[in] qkv_format QKV format.
+ * \param[in] interleaved Whether to use interleaved rotary position embedding.
+ * \param[in] cp_size Context parallel world size.
+ * \param[in] cp_rank Context parallel rank.
 * \param[in] s Length of the s dimension of output_grads.
 * \param[in] b Length of the b dimension of output_grads.
 * \param[in] h Length of the h dimension of output_grads.
 * \param[in] d Length of the d dimension of output_grads.
 * \param[in] d2 Length of the d dimension of freqs.
- * \param[in] stride_s Stride of the s dimension of output_grads.
+ * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of output_grads.
- * \param[in] stride_b Stride of the b dimension of output_grads.
+ * \param[in] stride_b Stride of the b dimension of output_grads. (0 for thd)
 * \param[in] stride_h Stride of the h dimension of output_grads.
 * \param[in] stride_d Stride of the d dimension of output_grads.
- * \param[in] o_stride_s Stride of the s dimension of input_grads.
- * \param[in] o_stride_b Stride of the b dimension of input_grads.
- * \param[in] o_stride_h Stride of the h dimension of input_grads.
- * \param[in] o_stride_d Stride of the d dimension of input_grads.
 * \param[in] stream CUDA stream used for the operation.
 */
-void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs,
-                              NVTETensor input_grads, const int s, const int b, const int h,
-                              const int d, const int d2, const int stride_s, const int stride_b,
-                              const int stride_h, const int stride_d, const int o_stride_s,
-                              const int o_stride_b, const int o_stride_h, const int o_stride_d,
-                              cudaStream_t stream);
+void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
+                              const NVTETensor freqs, NVTETensor input_grads,
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const int cp_size, const int cp_rank, const int s, const int b,
+                              const int h, const int d, const int d2, const int stride_s_or_t,
+                              const int stride_b, const int stride_h, const int stride_d,
+                              cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the input tensor in thd format.
*
* \param[in] input Input tensor for fused rope.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of input.
* \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_t Stride of the t dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient to calculate.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of output_grads.
* \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_t Stride of the t dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
#ifdef __cplusplus
}  // extern "C"
#endif
......
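Since the thd layout has no batch dimension, the merged entry points reuse one set of stride arguments for every format. A small illustration of the stride values the PyTorch wrapper ends up passing (shapes are made up; see the fused_rope.cpp changes further down for the authoritative mapping):

```python
import torch

sbhd = torch.empty(128, 2, 16, 64)  # (s, b, h, d)
thd = torch.empty(256, 16, 64)      # (t, h, d) with t = sum of sequence lengths
print(sbhd.stride())  # (2048, 1024, 64, 1) -> stride_s_or_t, stride_b, stride_h, stride_d
print(thd.stride())   # (1024, 64, 1)       -> stride_s_or_t, stride_h, stride_d; stride_b is passed as 0
```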
@@ -265,16 +265,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
**************************************************************************************************/
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory);
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
-                               const bool transpose_output_memory);
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank);
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);
-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);
/***************************************************************************************************
 * Miscellaneous
......
@@ -7,138 +7,38 @@
#include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
-                              const bool transpose_output_memory) {
+                              const NVTE_QKV_Format qkv_format, const bool interleaved,
+                              const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
+                              const int cp_rank) {
  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-  TORCH_CHECK(input.size(0) <= freqs.size(0),
-              "expected freqs tensor has a longer sequence length than input");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
-  TORCH_CHECK(input.size(3) >= freqs.size(3),
-              "expected the last dim of the input tensor equals or is "
-              "greater than the freqs tensor");
  TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
-  // input sizes: (s, b, h, d)
-  // s: sequence length
-  // b: batch size
-  // h: head num
-  // d: dim of each head
-  const int s = input.size(0);
-  const int b = input.size(1);
-  const int h = input.size(2);
-  const int d = input.size(3);
-  // input strides
-  const int stride_s = input.stride(0);
-  const int stride_b = input.stride(1);
-  const int stride_h = input.stride(2);
-  const int stride_d = input.stride(3);
-  // freqs' shape is always (s, 1, 1, d2), so the strides are same under
-  // different memory formats
-  const int d2 = freqs.size(3);
  // output
-  auto act_options = input.options().requires_grad(false);
-  at::Tensor output;
-  if (transpose_output_memory) {
-    output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
-  } else {
-    output = torch::empty({s, b, h, d}, act_options);
-  }
-  // output strides
-  const int o_stride_s = output.stride(0);
-  const int o_stride_b = output.stride(1);
-  const int o_stride_h = output.stride(2);
-  const int o_stride_d = output.stride(3);
+  auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device());
+  auto output = at::empty(input.sizes(), act_options);
  auto input_cu = makeTransformerEngineTensor(input);
  auto freqs_cu = makeTransformerEngineTensor(freqs);
  auto output_cu = makeTransformerEngineTensor(output);
-  nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2,
-                          stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
-                          o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
-  return output;
-}
+  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const bool transpose_output_memory) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(output_grads.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = output_grads.size(0);
const int b = output_grads.size(1);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s = output_grads.stride(0);
const int stride_b = output_grads.stride(1);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
at::Tensor input_grads;
if (transpose_output_memory) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
}
const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
const int o_stride_d = input_grads.stride(3);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
-at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
-                                  const at::Tensor &freqs, const int cp_size, const int cp_rank) {
-  using namespace transformer_engine::pytorch;
    TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
-    TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
-    TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
-    TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
-                "expected the second and third dims of the freqs tensor equal 1");
+    TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
+    TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
    TORCH_CHECK(input.size(2) >= freqs.size(3),
                "expected the last dim of the input tensor equals or is "
                "greater than the freqs tensor");
-    TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-                "Dtype of the freqs tensor must be float");
    // input sizes: (t, h, d)
    // t: cumulative sum of sequence lengths
    // h: head num
    // d: dim of each head
-    const int t = input.size(0);
+    // const int t = input.size(0);
    const int h = input.size(1);
    const int d = input.size(2);
    // input strides
@@ -146,51 +46,86 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_
    const int stride_h = input.stride(1);
    const int stride_d = input.stride(2);
    // batch size
-    const int b = cu_seqlens.size(0) - 1;
+    const int b = cu_seqlens.value().size(0) - 1;
    // freqs' shape is (max_s, 1, 1, d2)
    const int max_s = freqs.size(0);
    const int d2 = freqs.size(3);
-    // output
-    auto act_options = input.options().requires_grad(false);
-    auto output = torch::empty({t, h, d}, act_options);
-    // output strides
-    const int o_stride_t = output.stride(0);
-    const int o_stride_h = output.stride(1);
-    const int o_stride_d = output.stride(2);
-    auto input_cu = makeTransformerEngineTensor(input);
-    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-    auto freqs_cu = makeTransformerEngineTensor(freqs);
-    auto output_cu = makeTransformerEngineTensor(output);
-    nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                                output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
-                                stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                                at::cuda::getCurrentCUDAStream());
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+    nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                            output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
+                            h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                            at::cuda::getCurrentCUDAStream());
    return output;
  }
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
// input sizes: (s, b, h, d) or (b, s, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1);
const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1);
const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(d >= d2,
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
  return output;
}
-at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
-                                   const at::Tensor &freqs, const int cp_size, const int cp_rank) {
+at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
+                               const NVTE_QKV_Format qkv_format, const bool interleaved,
+                               const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
+                               const int cp_rank) {
  using namespace transformer_engine::pytorch;
-  TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
-  TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
  TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
auto act_options =
at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device());
auto input_grads = at::empty(output_grads.sizes(), act_options);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
    TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
                "expected the last dim of the output_grads tensor equals or is "
                "greater than the freqs tensor");
-    TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
-                "Dtype of the freqs tensor must be float");
    // output_grads sizes: (t, h, d)
    // t: cumulative sum of sequence lengths
    // h: head num
    // d: dim of each head
-    const int t = output_grads.size(0);
+    // const int t = output_grads.size(0);
    const int h = output_grads.size(1);
    const int d = output_grads.size(2);
    // output_grads strides
@@ -198,25 +133,54 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
    const int stride_h = output_grads.stride(1);
    const int stride_d = output_grads.stride(2);
    // batch size
-    const int b = cu_seqlens.size(0) - 1;
+    const int b = cu_seqlens.value().size(0) - 1;
    // freqs' shape is (max_s, 1, 1, d2)
    const int max_s = freqs.size(0);
    const int d2 = freqs.size(3);
-    auto act_options = output_grads.options().requires_grad(false);
-    auto input_grads = torch::empty({t, h, d}, act_options);
-    const int o_stride_t = input_grads.stride(0);
-    const int o_stride_h = input_grads.stride(1);
-    const int o_stride_d = input_grads.stride(2);
-    auto output_grads_cu = makeTransformerEngineTensor(output_grads);
-    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens);
-    auto freqs_cu = makeTransformerEngineTensor(freqs);
-    auto input_grads_cu = makeTransformerEngineTensor(input_grads);
+    auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
+    nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                             input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
+                             max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
+                             at::cuda::getCurrentCUDAStream());
return input_grads;
}
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1);
const int b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1);
const int stride_b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(d >= d2,
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
-  nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
-                               input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
-                               stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
-                               at::cuda::getCurrentCUDAStream());
+  auto cu_seqlens_cu = transformer_engine::TensorWrapper();  // empty cu_seqlens tensor
+  nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
+                           input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
+                           h, d, d2, stride_s, stride_b, stride_h, stride_d,
+                           at::cuda::getCurrentCUDAStream());
  return input_grads;
......
@@ -229,10 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        py::call_guard<py::gil_scoped_release>());
  m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
-        py::call_guard<py::gil_scoped_release>());
-  m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
-        py::call_guard<py::gil_scoped_release>());
  // Misc
  m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
......
@@ -7,7 +7,12 @@ Rotary Position Embedding implementation of different types along with helper fu
"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_torch as tex
+from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
+__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module):
@@ -22,19 +27,24 @@ class RotaryPositionEmbedding(torch.nn.Module):
        seq_len_interpolation_factor: Optional[int] = None,
        pretrained_max_position_embeddings: Optional[int] = None,
        rotary_base: float = 10000.0,
+        interleaved: bool = False,
    ):
        """
        Parameters
        ----------
        dim: int
            Rotary embedding dimension.
        rotary_percent: float, default = 1.0
            Percent of rotary dimension to use for rotary position embeddings.
        seq_len_interpolation_factor: int, default = None
            If not None, discrete positions will be interpolated by this factor via the trick in
            https://arxiv.org/abs/2306.15595
        pretrained_max_position_embeddings: int, default = None
            Pre-trained max_position_embeddings before position interpolation.
+        rotary_base: float, default = 10000.0
+            Base of the rotary position embedding.
+        interleaved: bool, default = False
+            Whether to use interleaved rotary position embedding.
        """
        super().__init__()
        if rotary_percent < 1.0:
@@ -50,17 +60,18 @@ class RotaryPositionEmbedding(torch.nn.Module):
        )
        self.register_buffer("inv_freq", inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
+        self.interleaved = interleaved
    def forward(self, max_seq_len: int, offset: int = 0):
        """
        Create rotary position embedding frequencies.
        Parameters
        ----------
        max_seq_len: int
            Sequence length of a sample.
        offset: int, default = 0
            Fixed offset for frequencies.
        """
        seq = (
            torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
@@ -84,7 +95,12 @@ class RotaryPositionEmbedding(torch.nn.Module):
        freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
        # first part even vector components, second part odd vector components,
        # 2 * dim in dimension size
+        if not self.interleaved:
            emb = torch.cat((freqs, freqs), dim=-1)
+        else:
+            emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
+                freqs.shape[0], -1
+            )
        # emb [seq_length, .., dim]
        return emb.reshape(emb.size(0), 1, 1, emb.size(1))
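For a single position with two frequencies, the two layouts produced above differ only in how the duplicated frequencies are ordered (a toy example, not part of the test suite):

```python
import torch

freqs = torch.tensor([[0.1, 0.2]])  # one position, dim/2 = 2 frequencies
non_interleaved = torch.cat((freqs, freqs), dim=-1)
# tensor([[0.1000, 0.2000, 0.1000, 0.2000]])  -> channels (i, i + d/2) share a frequency
interleaved = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(freqs.shape[0], -1)
# tensor([[0.1000, 0.1000, 0.2000, 0.2000]])  -> adjacent channels (2i, 2i + 1) share a frequency
```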
@@ -104,61 +120,146 @@ class FusedRoPEFunc(torch.autograd.Function):
        t: torch.Tensor,
        freqs: torch.Tensor,
        tensor_format: str = "sbhd",
+        interleaved: bool = False,
        cu_seqlens: Union[torch.Tensor, None] = None,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> torch.Tensor:
-        # pylint: disable=missing-function-docstring
+        """Fused RoPE forward."""
        if freqs.dtype != torch.float32:
            freqs = freqs.float()
-        if tensor_format == "sbhd":
-            output = tex.fused_rope_forward(t, freqs, False)
-        elif tensor_format == "bshd":
-            output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
-        elif tensor_format == "thd":
-            output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
-        else:
-            raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
+        assert tensor_format in (
+            "sbhd",
+            "bshd",
+            "thd",
+        ), f"Unsupported tensor_format: {tensor_format}."
+        output = tex.fused_rope_forward(
+            t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank
+        )
        ctx.save_for_backward(freqs, cu_seqlens)
        ctx.tensor_format = tensor_format
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
+        ctx.interleaved = interleaved
        return output
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
-        # pylint: disable=missing-function-docstring
+        """Fused RoPE backward."""
        freqs, cu_seqlens = ctx.saved_tensors
-        if ctx.tensor_format == "sbhd":
-            grad_input = tex.fused_rope_backward(grad_output, freqs, False)
-        elif ctx.tensor_format == "bshd":
-            grad_input = tex.fused_rope_backward(
-                grad_output.transpose(0, 1), freqs, True
-            ).transpose(0, 1)
-        elif ctx.tensor_format == "thd":
-            grad_input = tex.fused_rope_thd_backward(
-                grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
-            )
-        else:
-            raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
-        return grad_input, None, None, None, None, None
+        grad_input = tex.fused_rope_backward(
+            grad_output,
+            freqs,
+            QKVFormat[ctx.tensor_format],
+            ctx.interleaved,
+            cu_seqlens,
+            ctx.cp_size,
+            ctx.cp_rank,
+        )
+        return grad_input, None, None, None, None, None, None
-def _rotate_half(x: torch.Tensor) -> torch.Tensor:
-    """
-    change sign so the last dimension becomes [-odd, +even]
-    """
+def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
+    """Change sign so the last dimension becomes [-odd, +even]
+    Args:
+        x: torch.Tensor. Input tensor.
+        interleaved: bool. Whether to use interleaved rotary position embedding.
+    Returns:
+        Tensor: Tensor rotated half.
+    """
-    x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
-    x1, x2 = x.unbind(dim=-2)
-    return torch.cat((-x2, x1), dim=-1)
+    if not interleaved:
+        x1, x2 = torch.chunk(x, 2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    # interleaved
+    x1 = x[:, :, :, ::2]
+    x2 = x[:, :, :, 1::2]
+    x_new = torch.stack((-x2, x1), dim=-1)
+    return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
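A small worked example of what `_rotate_half` computes in both modes (illustrative values only):

```python
import torch

x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])  # shape (1, 1, 1, 4)

# non-interleaved: split into halves (1, 2) / (3, 4) -> (-3, -4, 1, 2)
x1, x2 = torch.chunk(x, 2, dim=-1)
print(torch.cat((-x2, x1), dim=-1))  # tensor([[[[-3., -4.,  1.,  2.]]]])

# interleaved: even/odd channels are paired -> (-2, 1, -4, 3)
x1, x2 = x[..., ::2], x[..., 1::2]
print(torch.stack((-x2, x1), dim=-1).view(1, 1, 1, -1))  # tensor([[[[-2.,  1., -4.,  3.]]]])
```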
def _apply_rotary_pos_emb_base(
t: torch.Tensor,
freqs: torch.Tensor,
tensor_format: str = "sbhd",
interleaved: bool = False,
) -> torch.Tensor:
"""
Base implementation of applying rotary positional embedding tensor to the input tensor.
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
`[seq, bs, ...]`.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
"""
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) # [seq, 1, 1, dim] -> [1, seq, 1, dim]
# cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
t = (t * cos_) + (_rotate_half(t, interleaved) * sin_)
return torch.cat((t, t_pass), dim=-1)
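A quick sanity sketch of the base path above: when `d2 < d`, only the first `d2` channels are rotated and the trailing channels pass through unchanged. `_apply_rotary_pos_emb_base` is a module-private helper, so importing it directly is only for illustration:

```python
import torch
from transformer_engine.pytorch.dot_product_attention.rope import _apply_rotary_pos_emb_base

s, b, h, d, d2 = 8, 2, 4, 16, 8
t = torch.randn(s, b, h, d)
freqs = torch.randn(s, 1, 1, d2)  # any float tensor of the right shape works for this check

out = _apply_rotary_pos_emb_base(t, freqs, tensor_format="sbhd", interleaved=False)
assert torch.equal(out[..., d2:], t[..., d2:])  # t_pass is untouched
```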
def _get_freqs_on_this_cp_rank(
freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
"""Get the position embedding on the current context parallel rank.
Args:
freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank.
"""
if cp_size > 1:
cp_seg = seqlen // 2
full_seqlen = cp_size * seqlen
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
# cp_size == 1
return freqs[:seqlen]
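A worked toy example of the slicing above, which gives each context-parallel rank one chunk from the front of the full sequence and the mirrored chunk from the back (the usual load-balanced context-parallel layout):

```python
import torch

freqs = torch.arange(16).view(16, 1, 1, 1)  # stand-in for a [s2, 1, 1, d2] freqs tensor

def slice_for_rank(cp_rank, cp_size=2, seqlen=8):
    # same arithmetic as _get_freqs_on_this_cp_rank
    cp_seg = seqlen // 2
    full = cp_size * seqlen
    return torch.cat([
        freqs[cp_rank * cp_seg:(cp_rank + 1) * cp_seg],
        freqs[full - (cp_rank + 1) * cp_seg:full - cp_rank * cp_seg],
    ]).flatten().tolist()

print(slice_for_rank(0))  # [0, 1, 2, 3, 12, 13, 14, 15]
print(slice_for_rank(1))  # [4, 5, 6, 7, 8, 9, 10, 11]
```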
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
+    interleaved: bool = False,
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
@@ -175,11 +276,13 @@ def apply_rotary_pos_emb(
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
-    fused: bool, default = False
-        Whether to use a fused applying RoPE implementation.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`. 'thd' is only supported when `fused` is True.
+    interleaved: bool, default = False
+        Whether to use interleaved rotary position embedding.
+    fused: bool, default = False
+        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
@@ -189,37 +292,40 @@ def apply_rotary_pos_emb(
    cp_rank: int, default = 0.
        Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
    """
-    if fused:
-        assert (
-            tensor_format != "thd" or cu_seqlens is not None
-        ), "cu_seqlens must not be None when tensor_format is 'thd'."
-        return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
-    assert tensor_format in ("sbhd", "bshd"), (
-        "Only formats `sbhd` or `bshd` are supported for input tensor `t` "
-        f"when fused is False, got {tensor_format}."
-    )
-    max_seq_len = freqs.shape[0]
-    cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
-    # Only apply the rotary embeddings up to the sequence length of the running
-    # input.
-    assert (
-        cur_seq_len <= max_seq_len
-    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
-    freqs = freqs[:cur_seq_len]
-    if tensor_format == "bshd":
-        freqs = freqs.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]
-    # cos/sin first then dtype conversion for better precision
-    cos_ = torch.cos(freqs).to(t.dtype)
-    sin_ = torch.sin(freqs).to(t.dtype)
-    rot_dim = freqs.shape[-1]
-    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
-    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
-    # first part is cosine component
-    # second part is sine component, need to change signs with _rotate_half method
-    t = (t * cos_) + (_rotate_half(t) * sin_)
-    return torch.cat((t, t_pass), dim=-1)
+    assert (
+        tensor_format != "thd" or cu_seqlens is not None
+    ), "cu_seqlens must not be None when tensor_format is 'thd'."
+    if fused:
+        return FusedRoPEFunc.apply(
+            t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
+        )
+    # Unfused THD format
+    if tensor_format == "thd":
+        cu_seqlens = cu_seqlens // cp_size
+        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+        return torch.cat(
+            [
+                _apply_rotary_pos_emb_base(
+                    x.unsqueeze(1),
+                    _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
+                    interleaved=interleaved,
+                )
+                for x in torch.split(t, seqlens)
+            ]
+        ).squeeze(1)
+    # Unfused SBHD/BSHD format
+    if tensor_format == "sbhd":
+        seqlen = t.size(0)
+    elif tensor_format == "bshd":
+        seqlen = t.size(1)
+    else:
+        raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
+    return _apply_rotary_pos_emb_base(
+        t,
+        _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
+        tensor_format,
+        interleaved=interleaved,
+    )
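And a hedged end-to-end sketch of the unfused `thd` path with packed sequences (sequence lengths and shapes are illustrative; the unfused path is plain PyTorch ops):

```python
import torch
from transformer_engine.pytorch.dot_product_attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

heads, dim = 4, 16
cu_seqlens = torch.tensor([0, 6, 10], dtype=torch.int32)  # two sequences of length 6 and 4
t = torch.randn(int(cu_seqlens[-1]), heads, dim)          # packed (t, h, d) tensor

rope = RotaryPositionEmbedding(dim)
emb = rope(int(cu_seqlens[-1]))  # freqs covering the packed length

out = apply_rotary_pos_emb(
    t,
    emb,
    tensor_format="thd",
    fused=False,
    cu_seqlens=cu_seqlens,
)
print(out.shape)  # torch.Size([10, 4, 16])
```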