Commit f8c2af4c authored by yuguo's avatar yuguo
Browse files

Merge commit '1d903f5e' of...

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
...@@ -20,6 +20,7 @@ extern "C" { ...@@ -20,6 +20,7 @@ extern "C" {
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats) * (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor. * \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] output Output tensor. * \param[out] output Output tensor.
* \param[in] qkv_format QKV format. * \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding. * \param[in] interleaved Whether to use interleaved rotary position embedding.
...@@ -37,12 +38,12 @@ extern "C" { ...@@ -37,12 +38,12 @@ extern "C" {
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const NVTETensor freqs, const NVTETensor start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved, NVTETensor output, const NVTE_QKV_Format qkv_format,
const int cp_size, const int cp_rank, const int s, const int b, const bool interleaved, const int cp_size, const int cp_rank,
const int h, const int d, const int d2, const int stride_s_or_t, const int s, const int b, const int h, const int d, const int d2,
const int stride_b, const int stride_h, const int stride_d, const int stride_s_or_t, const int stride_b, const int stride_h,
cudaStream_t stream); const int stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope. /*! \brief Compute the backward of the fused rope.
* *
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment