"wrappers/vscode:/vscode.git/clone" did not exist on "10320b3916b9e728533c5ac2150802afe4d504e1"
communication_asm.h 1.35 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#pragma once
// SPDX-License-Identifier: MIT
 
torch::Tensor all_reduce_asm(torch::Tensor &input,
                             int64_t _ca,
                             torch::Tensor &reg_sig, torch::Tensor &reg_buffer, bool isGraph);

std::tuple<torch::Tensor, torch::Tensor>       // out, residual_out
all_reduce_rmsnorm(torch::Tensor &input,       // [m ,n]
                   torch::Tensor &residual_in, // [m ,n]
                   torch::Tensor &weight,      // [1 ,n]
                   torch::Tensor &bias,        // [1 ,n]
                   float epsilon,
                   // following are fused_allreduce args
                   int64_t _ca,
                   torch::Tensor &reg_sig, torch::Tensor &reg_buffer, bool isGraph);

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> // out, residual_out, yscale
all_reduce_rmsnorm_quant(torch::Tensor &input,          // [m ,n]
                         torch::Tensor &residual_in,    // [m ,n]
                         torch::Tensor &xscale,         // [1 ,n]
                         torch::Tensor &weight,         // [1 ,n]
                         torch::Tensor &bias,           // [1 ,n]
                         float epsilon,
                         // following are fused_allreduce args
                         int64_t _ca,
                         torch::Tensor &reg_sig, torch::Tensor &reg_buffer, bool isGraph);