Commit 706f0cfd authored by Rick Ho

general support for profiling

parent 881b10c2
#ifndef COMM_MANAGER_H
#define COMM_MANAGER_H

+#define NCCL_SAFE_CALL(__fn__) { \
+    auto __res__ = __fn__; \
+    if (__res__ != ncclSuccess) { \
+        fprintf(stderr, "NCCL Error at %s:%d value %d\n", __FILE__, __LINE__, __res__); \
+        exit(-1); \
+    } \
+}
+
#include <mpi.h>
#include "nccl.h"

@@ -17,7 +25,7 @@ struct CommManager {
            ncclGetUniqueId(&uid);
        }
        MPI_Bcast(&uid, sizeof(uid), MPI_BYTE, 0, MPI_COMM_WORLD);
-       ncclCommInitRank(&ncclcomm, size, uid, rank);
+       NCCL_SAFE_CALL(ncclCommInitRank(&ncclcomm, size, uid, rank));
    }
};
...
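The NCCL_SAFE_CALL macro added above aborts on any non-successful NCCL result and prints the raw error code. As a side note, not part of this commit: NCCL also provides ncclGetErrorString(), so a variant of the macro could print a human-readable message as well. A minimal sketch, assuming <cstdio>, <cstdlib> and nccl.h are included as in the header above:

// Hypothetical variant (not in this commit): also print the NCCL error string.
#define NCCL_SAFE_CALL_VERBOSE(__fn__) {                                        \
    ncclResult_t __res__ = (__fn__);                                            \
    if (__res__ != ncclSuccess) {                                               \
        fprintf(stderr, "NCCL Error at %s:%d: %s (%d)\n",                       \
                __FILE__, __LINE__, ncclGetErrorString(__res__), (int)__res__); \
        exit(-1);                                                               \
    }                                                                           \
}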
@@ -72,9 +72,9 @@ class MOELayer_raw(nn.Module):
        for i in range(self.num_expert):
            linear = nn.Linear(in_features=self.in_feat, out_features=self.hidden_feat)
            # print(linear.weight.shape)
-            self.weight1.data[i] = linear.weight.data
+            self.weight1.data[i] = (linear.weight.data)
            linear = nn.Linear(in_features=self.hidden_feat, out_features=self.out_feat)
-            self.weight2.data[i] = linear.weight.data
+            self.weight2.data[i] = (linear.weight.data)

    def forward(self, inp, gate):
        gate_long = gate.long()
@@ -91,11 +91,12 @@ class MOELayer_raw(nn.Module):
def test_module(moe, linear, inp, gate):
    linear.zero_grad()
    moe.zero_grad()
-    x = linear(inp)
+    x = (linear(inp))
    output = moe(x, gate)
+    # print(output)
+    if torch.distributed.get_rank() == 1:
        print(output)
        return output
-    print(output)
    y = output.mean()
    y.backward()
    return output, moe.weight.grad, linear.weight.grad, linear.bias.grad
@@ -117,6 +118,7 @@ def test():
    inp = torch.rand(batch_size, in_feat).cuda()
    gate = torch.randint(low=0, high=num_expert, size=(batch_size, ), requires_grad=False).int().cuda()
+    gate = torch.Tensor([0, 1, 0, 1]).int().cuda()

    moe_out = test_module(moe, linear, inp.clone(), gate.clone())
    raw_out = test_module(moe_raw, linear, inp.clone(), gate.clone())
@@ -128,4 +130,5 @@ def test():
        print('{} abs err {}'.format(name, err))

if __name__ == '__main__':
+    torch.distributed.init_process_group(backend='mpi')
    test()
@@ -21,7 +21,7 @@
#define CEIL(_x_,_y_) (((_x_)-1)/(_y_)+1)

// #define MOE_DEBUG
-// #define MOE_BREAKDOWN
+#define MOE_BREAKDOWN
// #define MOE_DEBUG_SCATTER

template <typename scalar_t>
@@ -192,16 +192,14 @@ void moe_cuda_forward_impl(
        output_buf = local_output_buf;
    }
-    h->sync(0);
#ifdef MOE_BREAKDOWN
+    h->sync();
    timestamp(t_scatter);
    fprintf(stderr, "Scatter time %.3lf us\n", getDuration(t_expert, t_scatter) *
            1e6);
#endif
-    h->sync(0);
-    // fprintf(stderr, "First %d in %.3f\n", cm->rank, print_first_float(input_buf));

    scalar_t alpha = 1, beta = 0;
    for (int i = 0, ptr = 0; i < num_expert; ++i) {
@@ -251,15 +249,6 @@ void moe_cuda_forward_impl(
            NCCL_SAFE_CALL(ncclGroupStart());
            for (int j = 0; j < cm->size; ++j) {
                int idx = i + j * num_expert;
-                if (expert_count[idx]) {
-                    NCCL_SAFE_CALL(ncclRecv(
-                        local_output_buf + expert_ptr[idx] * out_feat,
-                        expert_count[idx] * out_feat * sizeof(scalar_t),
-                        ncclChar,
-                        j,
-                        cm->ncclcomm,
-                        h->getStream(0)));
-                }
                if (all_expert_count[idx]) {
                    NCCL_SAFE_CALL(ncclSend(
                        output_buf + send_ptr * out_feat,
@@ -270,21 +259,36 @@ void moe_cuda_forward_impl(
                        h->getStream(0)));
                    send_ptr += all_expert_count[idx];
                }
+                if (expert_count[idx]) {
+                    NCCL_SAFE_CALL(ncclRecv(
+                        local_output_buf + expert_ptr[idx] * out_feat,
+                        expert_count[idx] * out_feat * sizeof(scalar_t),
+                        ncclChar,
+                        j,
+                        cm->ncclcomm,
+                        h->getStream(0)));
+                }
            }
            NCCL_SAFE_CALL(ncclGroupEnd());
        }
    }

+#ifdef MOE_BREAKDOWN
+    h->sync(0);
+    timestamp(t_gather);
+    fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
+            1e6);
+#endif
+
    batch_gather_kernel<scalar_t>
        <<<batch_size, 256, 0, h->getStream(0)>>>(out_feat, d_pos,
            local_output_buf, output);
    h->sync(0);

#ifdef MOE_BREAKDOWN
-    timestamp(t_gather);
+    timestamp(t_end);
-    fprintf(stderr, "Gather time %.3lf us\n", getDuration(t_mm, t_gather) *
+    fprintf(stderr, "Local gather %.3lf us\n", getDuration(t_gather, t_end) *
            1e6);
-    fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_gather) *
+    fprintf(stderr, "Overall time %.3lf us\n", getDuration(t_init, t_end) *
            1e6);
#endif
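The ncclSend/ncclRecv pairs above are issued between ncclGroupStart() and ncclGroupEnd(), so NCCL fuses them into a single launch; within a group the relative order of sends and receives on a rank does not affect deadlock-freedom, which is why the receive block can be moved after the send block here. For illustration only, the same grouped-exchange pattern in isolation; the function name and the per-peer count arrays are hypothetical, it reuses the NCCL_SAFE_CALL macro from above, and it counts float elements with ncclFloat while the kernel above transfers raw bytes with ncclChar:

// Illustrative sketch, not from this repository.
#include <cstddef>
#include <cuda_runtime.h>
#include <nccl.h>

void moe_style_exchange(const float* sendbuf, float* recvbuf,
                        const size_t* sendcounts, const size_t* recvcounts,
                        int nranks, ncclComm_t comm, cudaStream_t stream) {
    // All point-to-point calls sit inside one group, so NCCL launches them
    // together and the per-peer send/recv order cannot deadlock.
    NCCL_SAFE_CALL(ncclGroupStart());
    size_t soff = 0, roff = 0;
    for (int j = 0; j < nranks; ++j) {
        if (sendcounts[j]) {
            NCCL_SAFE_CALL(ncclSend(sendbuf + soff, sendcounts[j], ncclFloat, j, comm, stream));
        }
        if (recvcounts[j]) {
            NCCL_SAFE_CALL(ncclRecv(recvbuf + roff, recvcounts[j], ncclFloat, j, comm, stream));
        }
        soff += sendcounts[j];
        roff += recvcounts[j];
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
}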
...
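The MOE_BREAKDOWN timing relies on timestamp() and getDuration() helpers defined elsewhere in the repository; this diff only moves the h->sync(...) calls inside the profiling blocks so the host clock is read after the device stream has drained (CUDA kernels and NCCL calls are asynchronous with respect to the host). A plausible chrono-based sketch of those helpers, written as an assumption since their definitions are not shown in this commit; it matches how they are used above, with getDuration() returning seconds and the call sites multiplying by 1e6 to report microseconds:

// Assumed shape of the profiling helpers, not the repository's actual code.
#include <chrono>

#define timestamp(__var__) \
    auto __var__ = std::chrono::high_resolution_clock::now()

inline double getDuration(std::chrono::high_resolution_clock::time_point a,
                          std::chrono::high_resolution_clock::time_point b) {
    return std::chrono::duration<double>(b - a).count();
}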
@@ -21,7 +21,6 @@ def perf():
    moe = MOELayer(num_expert, io_feat, hidden_feat, io_feat).cuda()

    o = moe(inp, gate)
-    return
    o = moe(inp, gate)
    o = moe(inp, gate)
    o = moe(inp, gate)
...
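The perf script above drops its early return so the repeated moe(inp, gate) calls actually run, which amounts to warm-up plus measurement at the Python level. For reference only, a device-side version of the same "warm up, then measure" idea using CUDA events; this is not from this repository, and launch_fn, warmup and iters are hypothetical parameters:

// Hypothetical sketch: average per-iteration time of work submitted to a stream.
#include <cuda_runtime.h>

template <typename LaunchFn>
float time_on_stream_ms(LaunchFn launch_fn, cudaStream_t stream, int warmup, int iters) {
    for (int i = 0; i < warmup; ++i) launch_fn(stream);  // warm-up launches are not timed
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, stream);
    for (int i = 0; i < iters; ++i) launch_fn(stream);
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);          // wait until the timed work has finished
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return ms / iters;                   // average milliseconds per iteration
}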
#!/bin/bash
+if [ ! -z $OMPI_COMM_WORLD_LOCAL_RANK ]
+then
+    export CUDA_VISIBLE_DEVICES=$OMPI_COMM_WORLD_LOCAL_RANK
+fi
+
export PYTHONPATH=$PWD/build/lib.linux-x86_64-3.7
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python3.7/site-packages/torch/lib:$LD_LIBRARY_PATH

if [ -z $1 ]

@@ -21,5 +26,5 @@ then
    done
done
else
-    python $@
+    python $@ 2>logs/$OMPI_COMM_WORLD_RANK.log
fi