Commit f8c2af4c authored by yuguo's avatar yuguo
Browse files

Merge commit '1d903f5e' of...

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -20,6 +20,7 @@ extern "C" {
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] output Output tensor.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
......@@ -37,12 +38,12 @@ extern "C" {
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope.
*
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment