Commit 6cdb3cda authored by TiagoMAntunes

Fixed indentation (4 spaces now)

parent 303d0e93
@@ -12,119 +12,119 @@
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
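(For reference: CHECK_CUDA and CHECK_CONTIGUOUS are defined earlier in the file, outside this hunk. The usual PyTorch C++ extension idiom for them looks roughly like the two lines below; the exact definitions in this file are not shown in the diff, so treat these as assumptions.)

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")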
std::vector<torch::Tensor> moe_expert_count(
        torch::Tensor gate,
        size_t num_expert) {
    CHECK_INPUT(gate);
    return moe_cuda_expert_count(gate, num_expert);
}

std::vector<torch::Tensor> moe_local_scatter(
        torch::Tensor input,
        torch::Tensor pos) {
    CHECK_INPUT(input);
    return moe_cuda_local_scatter(input, pos);
}

std::vector<torch::Tensor> moe_local_gather(
        torch::Tensor output_buf,
        torch::Tensor pos) {
    CHECK_INPUT(output_buf);
    return moe_cuda_local_gather(output_buf, pos);
}
void merge_bias(torch::Tensor &input_buf, torch::Tensor &weight, at::optional<torch::Tensor> bias_o) {
    torch::Tensor bias = bias_o.value();
    weight = at::cat({weight, bias.unsqueeze(2)}, 2); // [W b]
    auto options = torch::TensorOptions()
        .device(input_buf.device())
        .dtype(input_buf.dtype());
    auto ones = at::ones(input_buf.size(0), options).unsqueeze(1);
    input_buf = at::cat({input_buf, ones}, 1); // [X 1]
}
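Aside (illustrative only, not part of this commit): merge_bias folds the bias into the weight so that Wx + b becomes a single matrix product [W b][x; 1]. Below is a minimal standalone check of that identity with shapes simplified to 2-D; in the file above, weight is per-expert, [num_expert x out_feat x in_feat], and the concatenations happen on dims 2 and 1 instead.

#include <torch/torch.h>
#include <iostream>

int main() {
    auto W = torch::randn({4, 3});                   // [out_feat x in_feat]
    auto b = torch::randn({4});                      // [out_feat]
    auto x = torch::randn({3});                      // [in_feat]

    auto Wb = torch::cat({W, b.unsqueeze(1)}, 1);    // [W b]
    auto x1 = torch::cat({x, torch::ones({1})}, 0);  // [x; 1]

    // Wx + b computed directly vs. via the augmented matrix; prints 1 (true).
    std::cout << torch::allclose(torch::mv(W, x) + b, torch::mv(Wb, x1)) << "\n";
    return 0;
}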
std::vector<torch::Tensor> moe_forward(
        torch::Tensor input_buf,            // [batch_size x in_feat]
        torch::Tensor expert_count,         // [batch_size]
        torch::Tensor weight,               // [num_expert x out_feat x in_feat]
        at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
        ) {
    // Wx+b = [W b] [x]
    //              [1]
    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_forward(input_buf, expert_count, weight);
}

std::vector<torch::Tensor> moe_backward(
        torch::Tensor grad_output_buf,      // [batch_size x out_feat]
        torch::Tensor input_buf,            // [batch_size x in_feat]
        torch::Tensor expert_count,
        torch::Tensor weight,               // [num_expert x out_feat x in_feat]
        at::optional<torch::Tensor> bias_o  // [num_expert x out_feat] or None
        ) {
    // Wx+b = [W b] [x]
    //              [1]
    if (bias_o.has_value()) merge_bias(input_buf, weight, bias_o);
    CHECK_INPUT(grad_output_buf);
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_backward(grad_output_buf, input_buf, expert_count, weight, bias_o.has_value());
}
#ifdef MOE_USE_NCCL

std::vector<torch::Tensor> moe_expert_exchange(
        torch::Tensor local_expert_count,
        size_t num_expert, size_t n_workers) {
    return moe_cuda_expert_exchange(local_expert_count, num_expert, n_workers);
}

std::vector<torch::Tensor> moe_global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        size_t batch_size, size_t n_workers) {
    CHECK_INPUT(input_buf);
    return moe_cuda_global_scatter(input_buf,
            local_expert_count, global_expert_count,
            batch_size, n_workers);
}

std::vector<torch::Tensor> moe_global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        size_t batch_size, size_t n_workers) {
    CHECK_INPUT(output_buf);
    return moe_cuda_global_gather(output_buf,
            local_expert_count, global_expert_count,
            batch_size, n_workers);
}

std::vector<torch::Tensor> moe_global_fused_forward(
        torch::Tensor input_buf,
        torch::Tensor weight,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long global_batch_size, long local_batch_size, long n_workers) {
    CHECK_INPUT(input_buf);
    CHECK_INPUT(weight);
    return moe_cuda_global_fused_forward(
            input_buf, weight, local_expert_count, global_expert_count,
            global_batch_size, local_batch_size, n_workers);
}

#include <c10d/ProcessGroupNCCL.hpp>

@@ -132,47 +132,47 @@ std::vector<torch::Tensor> moe_global_fused_forward(
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    ncclComm_t getcomm(at::Device dev) {
        auto key = std::to_string(dev.index());
#ifdef ENABLE_NCCL_P2P_SUPPORT
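        // Bootstrap a dedicated communicator: rank 0 generates a fresh ncclUniqueId,
        // broadcastUniqueNCCLID distributes it to the other ranks through the process
        // group, and every rank then joins a communicator of size getSize().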
        ncclUniqueId ncclID;
        int rank = getRank();
        if (rank == 0) {
            ncclGetUniqueId(&ncclID);
        }
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
        ncclComm_t comm;
        ncclCommInitRank(&comm, getSize(), ncclID, rank);
        return comm;
#else
        auto v = getNCCLComm(key, {dev});
        if (v.size() == 0) {
            std::cerr << "PyTorch has nothing\n";
            return 0;
        }
        int count;
        ncclCommCount(v[0]->getNcclComm(), &count);
        std::cerr << "PyTorch has " << v.size() << " comms, comm 0 size " << count << "\n";
        return v[0]->getNcclComm();
#endif
    }
};
void moe_ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}
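Aside (illustrative only, not from this commit): moe_ensure_nccl reaches ProcessGroupNCCL's protected members by casting the process group to a derived class that adds no state of its own. A minimal sketch of that pattern is below, with made-up names Base and Hack; the double cast works in practice precisely because the derived type adds no data members or virtuals, though strictly it is outside what the C++ standard guarantees, which is why the original calls it a hack.

#include <iostream>

class Base {
protected:
    int secret() const { return 42; }
};

// Derived type whose only purpose is to re-expose the protected API,
// mirroring what HackNCCLGroup does with ProcessGroupNCCL above.
class Hack : public Base {
public:
    int expose() const { return secret(); }
};

int main() {
    Base b;
    Hack* h = (Hack*)(void*)&b;         // same double cast as in moe_ensure_nccl
    std::cout << h->expose() << "\n";   // prints 42
    return 0;
}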
#endif  // MOE_USE_NCCL
@@ -186,7 +186,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)"); m.def("global_scatter", &moe_global_scatter, "MoE global scatter (CUDA)");
m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)"); m.def("global_gather", &moe_global_gather, "MoE global gather (CUDA)");
m.def("global_fused_forward", &moe_global_fused_forward, m.def("global_fused_forward", &moe_global_fused_forward,
"MoE global gather (CUDA)"); "MoE global gather (CUDA)");
m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm"); m.def("ensure_nccl", &moe_ensure_nccl, "MoE ensure torch nccl comm");
#endif #endif
m.def("forward", &moe_forward, "MoE forward (CUDA)"); m.def("forward", &moe_forward, "MoE forward (CUDA)");
...